您好,登錄后才能下訂單哦!
1.簡介
K均值聚類算法是先隨機選取K個對象作為初始的聚類中心。然后計算每個對象與各個種子聚類中心之間的距離,把每個對象分配給距離它最近的聚類中心。聚類中心以及分配給它們的對象就代表一個聚類。一旦全部對象都被分配了,每個聚類的聚類中心會根據聚類中現有的對象被重新計算。這個過程將不斷重復直到滿足某個終止條件。終止條件可以是沒有(或最小數目)對象被重新分配給不同的聚類,沒有(或最小數目)聚類中心再發生變化,誤差平方和局部最小。
2.什么是聚類
聚類是一個將數據集中在某些方面相似的數據成員進行分類組織的過程,聚類就是一種發現這種內在結構的技術,聚類技術經常被稱為無監督學習。
3.什么是k均值聚類
k均值聚類是最著名的劃分聚類算法,由于簡潔和效率使得他成為所有聚類算法中最廣泛使用的。給定一個數據點集合和需要的聚類數目k,k由用戶指定,k均值算法根據某個距離函數反復把數據分入k個聚類中。
4.實現
Java代碼如下:
package org.algorithm; import java.util.ArrayList; import java.util.Random; /** * K均值聚類算法 */ public class Kmeans { private int k; // 分成多少簇 private int m; // 迭代次數 private int dataSetLength; // 數據集元素個數,即數據集的長度 private ArrayList<float[]> dataSet; // 數據集鏈表 private ArrayList<float[]> center; // 中心鏈表 private ArrayList<ArrayList<float[]>> cluster; // 簇 private ArrayList<float> jc; // 誤差平方和,k越接近dataSetLength,誤差越小 private Random random; /** * 設置需分組的原始數據集 * * @param dataSet */ public void setDataSet(ArrayList<float[]> dataSet) { this.dataSet = dataSet; } /** * 獲取結果分組 * * @return 結果集 */ public ArrayList<ArrayList<float[]>> getCluster() { return cluster; } /** * 構造函數,傳入需要分成的簇數量 * * @param k * 簇數量,若k<=0時,設置為1,若k大于數據源的長度時,置為數據源的長度 */ public Kmeans(int k) { if (k <= 0) { k = 1; } this.k = k; } /** * 初始化 */ private void init() { m = 0; random = new Random(); if (dataSet == null || dataSet.size() == 0) { initDataSet(); } dataSetLength = dataSet.size(); if (k > dataSetLength) { k = dataSetLength; } center = initCenters(); cluster = initCluster(); jc = new ArrayList<float>(); } /** * 如果調用者未初始化數據集,則采用內部測試數據集 */ private void initDataSet() { dataSet = new ArrayList<float[]>(); // 其中{6,3}是一樣的,所以長度為15的數據集分成14簇和15簇的誤差都為0 float[][] dataSetArray = new float[][] { { 8, 2 }, { 3, 4 }, { 2, 5 }, { 4, 2 }, { 7, 3 }, { 6, 2 }, { 4, 7 }, { 6, 3 }, { 5, 3 }, { 6, 3 }, { 6, 9 }, { 1, 6 }, { 3, 9 }, { 4, 1 }, { 8, 6 } }; for (int i = 0; i < dataSetArray.length; i++) { dataSet.add(dataSetArray[i]); } } /** * 初始化中心數據鏈表,分成多少簇就有多少個中心點 * * @return 中心點集 */ private ArrayList<float[]> initCenters() { ArrayList<float[]> center = new ArrayList<float[]>(); int[] randoms = new int[k]; Boolean flag; int temp = random.nextint(dataSetLength); randoms[0] = temp; for (int i = 1; i < k; i++) { flag = true; while (flag) { temp = random.nextint(dataSetLength); int j = 0; // 不清楚for循環導致j無法加1 // for(j=0;j<i;++j) // { // if(temp==randoms[j]); // { // break; // } // } while (j < i) { if (temp == randoms[j]) { break; } j++; } if (j == i) { flag = false; } } randoms[i] = temp; } // 測試隨機數生成情況 // for(int i=0;i<k;i++) // { // System.out.println("test1:randoms["+i+"]="+randoms[i]); // } // System.out.println(); for (int i = 0; i < k; i++) { center.add(dataSet.get(randoms[i])); // 生成初始化中心鏈表 } return center; } /** * 初始化簇集合 * * @return 一個分為k簇的空數據的簇集合 */ private ArrayList<ArrayList<float[]>> initCluster() { ArrayList<ArrayList<float[]>> cluster = new ArrayList<ArrayList<float[]>>(); for (int i = 0; i < k; i++) { cluster.add(new ArrayList<float[]>()); } return cluster; } /** * 計算兩個點之間的距離 * * @param element * 點1 * @param center * 點2 * @return 距離 */ private float distance(float[] element, float[] center) { float distance = 0.0f; float x = element[0] - center[0]; float y = element[1] - center[1]; float z = x * x + y * y; distance = (float) Math.sqrt(z); return distance; } /** * 獲取距離集合中最小距離的位置 * * @param distance * 距離數組 * @return 最小距離在距離數組中的位置 */ private int minDistance(float[] distance) { float minDistance = distance[0]; int minLocation = 0; for (int i = 1; i < distance.length; i++) { if (distance[i] < minDistance) { minDistance = distance[i]; minLocation = i; } else if (distance[i] == minDistance) // 如果相等,隨機返回一個位置 { if (random.nextint(10) < 5) { minLocation = i; } } } return minLocation; } /** * 核心,將當前元素放到最小距離中心相關的簇中 */ private void clusterSet() { float[] distance = new float[k]; for (int i = 0; i < dataSetLength; i++) { for (int j = 0; j < k; j++) { distance[j] = distance(dataSet.get(i), center.get(j)); // System.out.println("test2:"+"dataSet["+i+"],center["+j+"],distance="+distance[j]); } int minLocation = minDistance(distance); // System.out.println("test3:"+"dataSet["+i+"],minLocation="+minLocation); // System.out.println(); cluster.get(minLocation).add(dataSet.get(i)); // 核心,將當前元素放到最小距離中心相關的簇中 } } /** * 求兩點誤差平方的方法 * * @param element * 點1 * @param center * 點2 * @return 誤差平方 */ private float errorSquare(float[] element, float[] center) { float x = element[0] - center[0]; float y = element[1] - center[1]; float errSquare = x * x + y * y; return errSquare; } /** * 計算誤差平方和準則函數方法 */ private void countRule() { float jcF = 0; for (int i = 0; i < cluster.size(); i++) { for (int j = 0; j < cluster.get(i).size(); j++) { jcF += errorSquare(cluster.get(i).get(j), center.get(i)); } } jc.add(jcF); } /** * 設置新的簇中心方法 */ private void setNewCenter() { for (int i = 0; i < k; i++) { int n = cluster.get(i).size(); if (n != 0) { float[] newCenter = { 0, 0 }; for (int j = 0; j < n; j++) { newCenter[0] += cluster.get(i).get(j)[0]; newCenter[1] += cluster.get(i).get(j)[1]; } // 設置一個平均值 newCenter[0] = newCenter[0] / n; newCenter[1] = newCenter[1] / n; center.set(i, newCenter); } } } /** * 打印數據,測試用 * * @param dataArray * 數據集 * @param dataArrayName * 數據集名稱 */ public void printDataArray(ArrayList<float[]> dataArray, String dataArrayName) { for (int i = 0; i < dataArray.size(); i++) { System.out.println("print:" + dataArrayName + "[" + i + "]={" + dataArray.get(i)[0] + "," + dataArray.get(i)[1] + "}"); } System.out.println("==================================="); } /** * Kmeans算法核心過程方法 */ private void kmeans() { init(); // printDataArray(dataSet,"initDataSet"); // printDataArray(center,"initCenter"); // 循環分組,直到誤差不變為止 while (true) { clusterSet(); // for(int i=0;i<cluster.size();i++) // { // printDataArray(cluster.get(i),"cluster["+i+"]"); // } countRule(); // System.out.println("count:"+"jc["+m+"]="+jc.get(m)); // System.out.println(); // 誤差不變了,分組完成 if (m != 0) { if (jc.get(m) - jc.get(m - 1) == 0) { break; } } setNewCenter(); // printDataArray(center,"newCenter"); m++; cluster.clear(); cluster = initCluster(); } // System.out.println("note:the times of repeat:m="+m);//輸出迭代次數 } /** * 執行算法 */ public void execute() { long startTime = System.currentTimeMillis(); System.out.println("kmeans begins"); kmeans(); long endTime = System.currentTimeMillis(); System.out.println("kmeans running time=" + (endTime - startTime) + "ms"); System.out.println("kmeans ends"); System.out.println(); } }
5.說明:
具體代碼是從網上找的,根據自己的理解加了注釋和進行部分修改,若注釋有誤還望指正
6.測試
package org.test; import java.util.ArrayList; import org.algorithm.Kmeans; public class KmeansTest { public static void main(String[] args) { //初始化一個Kmean對象,將k置為10 Kmeans k=new Kmeans(10); ArrayList<float[]> dataSet=new ArrayList<float[]>(); dataSet.add(new float[]{1,2}); dataSet.add(new float[]{3,3}); dataSet.add(new float[]{3,4}); dataSet.add(new float[]{5,6}); dataSet.add(new float[]{8,9}); dataSet.add(new float[]{4,5}); dataSet.add(new float[]{6,4}); dataSet.add(new float[]{3,9}); dataSet.add(new float[]{5,9}); dataSet.add(new float[]{4,2}); dataSet.add(new float[]{1,9}); dataSet.add(new float[]{7,8}); //設置原始數據集 k.setDataSet(dataSet); //執行算法 k.execute(); //得到聚類結果 ArrayList<ArrayList<float[]>> cluster=k.getCluster(); //查看結果 for (int i=0;i<cluster.size();i++) { k.printDataArray(cluster.get(i), "cluster["+i+"]"); } } }
總結:測試代碼已經通過。并對聚類的結果進行了查看,結果基本上符合要求。至于有沒有更精確的算法有待發現。具體的實踐還有待挖掘
總結
以上就是本文關于K均值聚類算法的Java版實現代碼示例的全部內容,希望對大家有所幫助。感興趣的朋友可以繼續參閱本站其他相關專題。如有不足之處,歡迎留言指出。感謝朋友們對本站的支持!
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。