Implementing the kNN algorithm: a complete program
Published: 2019-06-23

1. Main steps:

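Read the initial data from a file → compute the distance from the target point to every already-labelled point → assign the target point to a class by the nearest-distance rule (a majority vote among its k closest points).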
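The two rules in these steps can be written out explicitly. This is a sketch matching the program below: the distance is Euclidean over the n numeric feature columns (n = 4 for the iris data, with the label column excluded), and the predicted class is a majority vote among the k nearest training points (k = 3 in `main`). The notation N_k(x) for the neighbor set is introduced here for illustration only.

$$d(\mathbf{x}, \mathbf{x}_i) = \sqrt{\sum_{j=1}^{n} \left(x_j - x_{i,j}\right)^2}$$

$$\hat{y}(\mathbf{x}) = \arg\max_{c} \sum_{i \in N_k(\mathbf{x})} \mathbb{1}\left[y_i = c\right]$$

Here $N_k(\mathbf{x})$ is the set of the $k$ training points with the smallest $d(\mathbf{x}, \mathbf{x}_i)$, and $y_i$ is the label of training point $\mathbf{x}_i$.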

import csv
import math
import operator
import random


def loadDataset(filename, split, trainingSet, testSet):
    """Read the CSV file and randomly split its rows into training and test sets."""
    with open(filename, 'r', newline='') as csvfile:
        dataset = list(csv.reader(csvfile))
    for row in dataset:
        if not row:  # skip blank lines, e.g. a trailing newline at the end of the file
            continue
        for y in range(4):  # the first 4 columns are numeric features
            row[y] = float(row[y])
        if random.random() < split:
            trainingSet.append(row)
        else:
            testSet.append(row)


def euclideanDistance(instance1, instance2, length):
    """Euclidean distance over the first `length` attributes."""
    distance = 0
    for x in range(length):
        distance += pow(instance1[x] - instance2[x], 2)
    return math.sqrt(distance)


def getNeighbors(trainingSet, testInstance, k):
    """Return the k training instances closest to testInstance."""
    distances = []
    length = len(testInstance) - 1  # exclude the class label in the last column
    for trainInstance in trainingSet:
        dist = euclideanDistance(testInstance, trainInstance, length)
        distances.append((trainInstance, dist))
    distances.sort(key=operator.itemgetter(1))
    neighbors = []
    for x in range(k):
        neighbors.append(distances[x][0])
    return neighbors  # return after the loop, not inside it


def getResponse(neighbors):
    """Majority vote over the class labels (last column) of the neighbors."""
    classVotes = {}
    for neighbor in neighbors:
        response = neighbor[-1]
        classVotes[response] = classVotes.get(response, 0) + 1
    sortedVotes = sorted(classVotes.items(), key=operator.itemgetter(1), reverse=True)
    return sortedVotes[0][0]


def getAccuracy(testSet, predictions):
    """Percentage of test instances whose predicted label equals the actual label."""
    correct = 0
    for x in range(len(testSet)):
        if testSet[x][-1] == predictions[x]:
            correct += 1
    return (correct / float(len(testSet))) * 100.0


def main():
    # prepare data
    trainingSet = []
    testSet = []
    split = 0.70
    loadDataset(r'/home/zhoumiao/ML/02KNearestNeighbor/irisdata.txt', split, trainingSet, testSet)
    print('Train set: ' + repr(len(trainingSet)))
    print('Test set: ' + repr(len(testSet)))
    # generate predictions
    predictions = []
    k = 3
    for x in range(len(testSet)):
        neighbors = getNeighbors(trainingSet, testSet[x], k)
        result = getResponse(neighbors)
        predictions.append(result)
        print('> predicted=' + repr(result) + ', actual=' + repr(testSet[x][-1]))
    accuracy = getAccuracy(testSet, predictions)
    print('Accuracy: ' + repr(accuracy) + '%')


if __name__ == '__main__':
    main()
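To try the same three steps without the iris file, here is a compact, self-contained sketch. It assumes Python 3.8+ (for `math.dist`) and uses a small hypothetical in-memory dataset in place of step 1; `knn_predict` is an illustrative name, not part of the original program, and it swaps the hand-written vote counting for `collections.Counter`. Rows use the same layout as above: numeric features followed by the class label in the last column.

```python
import math
from collections import Counter


def knn_predict(training_set, test_instance, k):
    """Classify test_instance by majority vote among its k nearest training rows.

    Each row is a list of numeric features followed by a class label (last column).
    """
    n_features = len(test_instance) - 1  # ignore the label column
    # Step 2: distance from the target point to every labelled point
    distances = [
        (math.dist(row[:n_features], test_instance[:n_features]), row[-1])
        for row in training_set
    ]
    distances.sort(key=lambda pair: pair[0])
    # Step 3: majority vote among the labels of the k nearest points
    votes = Counter(label for _, label in distances[:k])
    return votes.most_common(1)[0][0]


# Hypothetical toy data standing in for step 1 (reading the file)
training = [
    [1.0, 1.0, 'a'],
    [1.2, 0.8, 'a'],
    [0.9, 1.1, 'a'],
    [5.0, 5.0, 'b'],
    [5.2, 4.9, 'b'],
    [4.8, 5.1, 'b'],
]
tests = [[1.1, 1.0, 'a'], [5.1, 5.0, 'b']]
predictions = [knn_predict(training, t, 3) for t in tests]
accuracy = sum(p == t[-1] for p, t in zip(predictions, tests)) / len(tests) * 100.0
print(predictions, accuracy)  # ['a', 'b'] 100.0
```

Because the two toy clusters are well separated, both test points receive the label of their own cluster; with real data such as iris, the accuracy depends on the random split and on k.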

Reprinted from: https://blog.51cto.com/13831593/2173693
