import random from copy import deepcopy from time import time
  import numpy as np from numpy.linalg import norm
  from collections import Counter import time
  class Node(object):
      """A kd-tree node holding one training sample."""

      def __init__(self, item=None, label=None, dim=None, parent=None,
                   left_child=None, right_child=None):
          """
          item: the node's value (one sample's feature vector)
          label: the node's label; KDTree stores a length-1 array here
          dim: index of the feature (dimension) this node splits on
          parent: parent node, or None for the root
          left_child: subtree of samples whose value on `dim` is <= this node's
          right_child: subtree of samples whose value on `dim` is >= this node's
          """
          self.item = item
          self.label = label
          self.dim = dim
          self.parent = parent
          self.left_child = left_child
          self.right_child = right_child


  class KDTree(object):
      """kd-tree over labelled samples supporting k-nearest-neighbour voting."""

      def __init__(self, aList, labelList):
          """
          aList: array-like, rows are samples and columns are features
          labelList: array-like with one label per sample
          """
          self.__length = 0
          self.__root = self.__create(aList, labelList)

      def __create(self, aList, labelList, parentNode=None):
          '''
          Recursively build the kd-tree.

          aList: array-like, rows are samples and columns are features
          labelList: labels of the samples, one per row
          parentNode: parent of the subtree being built (None for the root)
          return: root node of the subtree, or None when there are no samples
          raises ValueError: if the samples have no features
          '''
          dataArray = np.array(aList)
          # No samples (an empty top-level input, or the empty right half of a
          # two-sample split) means no node.
          if len(dataArray) == 0:
              return None
          m, n = dataArray.shape
          if n == 0:
              raise ValueError('samples must have at least one feature')
          labelArray = np.array(labelList).reshape(m, 1)

          # Split on the feature with the largest variance (first one on ties).
          var_list = [np.var(dataArray[:, col]) for col in range(n)]
          max_index = var_list.index(max(var_list))

          # The median sample along that feature becomes this node.
          max_feat_ind_list = dataArray[:, max_index].argsort()
          mid = m // 2
          mid_item_index = max_feat_ind_list[mid]
          node = Node(
              dim=max_index,
              label=labelArray[mid_item_index],
              item=dataArray[mid_item_index],
              parent=parentNode,
          )
          self.__length += 1

          # Samples sorted before the median go left, those after it go right.
          left_index = max_feat_ind_list[:mid]
          right_index = max_feat_ind_list[mid + 1:]
          node.left_child = self.__create(
              dataArray[left_index], labelArray[left_index], node)
          node.right_child = self.__create(
              dataArray[right_index], labelArray[right_index], node)
          return node

      @property
      def length(self):
          """Number of nodes (samples) stored in the tree."""
          return self.__length

      @property
      def root(self):
          """Root node, or None for an empty tree."""
          return self.__root

      def transfer_dict(self, node):
          '''
          Describe the subtree under `node` as nested dictionaries.

          node: subtree root (pass self.root for the whole tree)
          return: None for an empty subtree, otherwise a one-entry dict keyed by
              tuple(node.item) whose value has the keys 'label', 'dim',
              'parent' (the parent's item tuple, or None), and 'left_child' /
              'right_child' (nested dicts of the same format, or None), e.g.
              {(1, 2, 3): {'label': 1, 'dim': 0, 'parent': None,
                           'left_child': {(2, 3, 4): {...}},
                           'right_child': {(4, 5, 6): {...}}}}
          '''
          if node is None:
              return None
          return {
              tuple(node.item): {
                  'label': node.label[0],
                  'dim': node.dim,
                  'parent': tuple(node.parent.item) if node.parent else None,
                  'left_child': self.transfer_dict(node.left_child),
                  'right_child': self.transfer_dict(node.right_child),
              }
          }

      def transfer_list(self, node, kdList=None):
          '''
          Flatten the subtree under `node` into a pre-order list of dicts.

          node: subtree root (pass self.root for the whole tree)
          kdList: list to append to; a fresh list is used when omitted
          return: kdList with one dict per node (keys 'item', 'label', 'dim',
              'parent', 'left_child', 'right_child'; related nodes are given
              by their item tuples), or None when `node` is None
          '''
          # A fresh list per call: a shared `[]` default would accumulate
          # entries across separate calls.
          if kdList is None:
              kdList = []
          if node is None:
              return None
          kdList.append({
              'item': tuple(node.item),
              'label': node.label[0],
              'dim': node.dim,
              'parent': tuple(node.parent.item) if node.parent else None,
              'left_child': tuple(node.left_child.item) if node.left_child else None,
              'right_child': tuple(node.right_child.item) if node.right_child else None,
          })
          self.transfer_list(node.left_child, kdList)
          self.transfer_list(node.right_child, kdList)
          return kdList

      def _find_nearest_neighbour(self, item):
          '''
          Descend from the root towards `item` and return where the descent stops.

          The descent stops at a node whose split coordinate equals item's, or
          when the child on item's side is missing. The returned node is only a
          starting candidate, not necessarily the true nearest neighbour.

          item: query sample
          return: the stopping node, or None for an empty tree
          '''
          if self.length == 0:
              return None
          node = self.__root
          if self.length == 1:
              return node
          while True:
              cur_dim = node.dim
              if item[cur_dim] == node.item[cur_dim]:
                  return node
              if item[cur_dim] < node.item[cur_dim]:
                  child = node.left_child
              else:
                  child = node.right_child
              if child is None:
                  return node
              node = child

      def knn_algo(self, item, k=1):
          '''
          Find the k training samples nearest to `item` and vote on its label.

          item: query sample (array-like with one value per feature)
          k: number of neighbours to consult, usually 1-5
          return: (label, node_list). node_list holds up to k entries
              [distance, tuple(sample), label], sorted by ascending Euclidean
              distance. label is the most common label among them; ties go to
              the label that appears first in node_list. An empty tree yields
              (None, []).
          raises ValueError: if k < 1
          '''
          if k < 1:
              raise ValueError('k must be at least 1')
          if self.__root is None:
              return None, []
          item = np.asarray(item)
          node_list = []
          self._search_subtree(item, self.__root, node_list, k)
          label = Counter(entry[2] for entry in node_list).most_common(1)[0][0]
          return label, node_list

      def left_search(self, item, node, nodeList, k):
          '''
          Search the subtree under `node` for the k samples nearest to `item`,
          trying each node's left child before its right child.

          A child is skipped when the distance from `item` to that child's side
          of the splitting plane is no smaller than the current k-th best
          distance.

          item: query sample
          node: subtree root
          nodeList: candidate [distance, tuple(sample), label] entries; updated
              in place so it ends sorted by distance and at most k long
          k: number of neighbours wanted
          return: nodeList
          '''
          if k < 1:
              return nodeList
          nodeList.sort(key=lambda entry: entry[0])
          del nodeList[k:]
          self._search_subtree(np.asarray(item), node, nodeList, k, left_first=True)
          return nodeList

      def right_search(self, item, node, nodeList, k):
          '''
          Search the subtree under `node` for the k samples nearest to `item`,
          trying each node's right child before its left child.

          A child is skipped when the distance from `item` to that child's side
          of the splitting plane is no smaller than the current k-th best
          distance.

          item: query sample
          node: subtree root
          nodeList: candidate [distance, tuple(sample), label] entries; updated
              in place so it ends sorted by distance and at most k long
          k: number of neighbours wanted
          return: nodeList
          '''
          if k < 1:
              return nodeList
          nodeList.sort(key=lambda entry: entry[0])
          del nodeList[k:]
          self._search_subtree(np.asarray(item), node, nodeList, k, left_first=False)
          return nodeList

      @staticmethod
      def _kth_distance(nodeList, k):
          """Pruning radius: k-th best distance, or infinity while fewer than k are held.

          Assumes nodeList is sorted by distance (maintained by _offer).
          """
          return nodeList[k - 1][0] if len(nodeList) >= k else float('inf')

      @staticmethod
      def _offer(nodeList, k, entry):
          """Insert `entry` if it beats the current k-th best; keep the list sorted and <= k long."""
          if len(nodeList) < k or entry[0] < nodeList[k - 1][0]:
              nodeList.append(entry)
              nodeList.sort(key=lambda candidate: candidate[0])
              del nodeList[k:]

      def _search_subtree(self, item, node, nodeList, k, left_first=None):
          """Pruned recursive k-NN search of the subtree under `node`.

          left_first: True visits left children first, False right children
          first, None visits the child on item's side of each split first.
          """
          if node is None:
              return
          dis = np.sqrt(np.sum((item - node.item) ** 2))
          self._offer(nodeList, k, [dis, tuple(node.item), node.label[0]])

          diff = item[node.dim] - node.item[node.dim]
          # Lower bounds on the distance from item to any sample of each child:
          # left samples are <= the split value on node.dim, right ones >= it.
          left = (node.left_child, max(diff, 0))
          right = (node.right_child, max(-diff, 0))
          go_left_first = diff <= 0 if left_first is None else left_first
          for child, bound in ((left, right) if go_left_first else (right, left)):
              # Re-read the radius: the first child's search may have shrunk it.
              if child is not None and bound < self._kth_distance(nodeList, k):
                  self._search_subtree(item, child, nodeList, k, left_first)
 
  if __name__ == '__main__':
      # Local import: the module-level name `time` is bound twice at the top
      # of the file (`from time import time`, then `import time`), so avoid
      # depending on which binding wins. perf_counter is the clock intended
      # for measuring elapsed time.
      from time import perf_counter

      t1 = perf_counter()
      dataArray = np.random.randint(0, 20, size=(10000, 2))
      label = np.random.randint(0, 3, size=(10000, 1))
      kd_tree = KDTree(dataArray, label)
      t2 = perf_counter()
      label, node_list = kd_tree.knn_algo([12, 7], k=5)
      # Stop the clock before printing so only the search itself is timed.
      t3 = perf_counter()
      print('点%s的最接近的前k个点为:%s' % ([12, 7], node_list))
      print('点%s的标签:%s' % ([12, 7], label))
      print('创建树耗时:', t2 - t1)
      print('搜索前k个最近邻点耗时:', t3 - t2)