// Written by Joyful Bioinfo (Joyful Bioinfo가 직접 작성하였습니다).

import java.io.*;  // Import the File class
import java.util.*; // Import the Scanner class to read text files

public class Kmeans {

    /** Upper bound on update rounds performed by {@link #RecursedFunction(List, List)}. */
    private static final int DEFAULT_MAX_ITERATIONS = 1000;

    /**
     * Reads a tab-separated text file into a list of rows.
     *
     * <p>Each line becomes one row holding its tab-separated fields. {@link String#split(String)}
     * drops trailing empty fields. The other methods of this class treat row 0 as a header and
     * cell 0 of every row as the sample name.
     *
     * @param input path of the file to read
     * @return the file contents, one inner list per line
     * @throws IOException if the file cannot be opened or read
     */
    public static List<List<String>> ReadFile(String input) throws IOException {
        List<List<String>> data = new ArrayList<>();
        // try-with-resources: the reader is closed even if readLine() throws.
        try (BufferedReader reader = new BufferedReader(new FileReader(input))) {
            String line;
            while ((line = reader.readLine()) != null) {
                data.add(new ArrayList<>(Arrays.asList(line.split("\t"))));
            }
        }
        return data;
    }

    /**
     * Returns the sample standard deviation (denominator {@code n - 1}) of {@code nums}.
     *
     * @param nums values to summarize
     * @return the sample standard deviation, or NaN when fewer than two values are given
     */
    public static double findDeviation(List<Double> nums) {
        if (nums.size() < 2) {
            return Double.NaN;
        }
        double mean = findMean(nums);
        double squareSum = 0;
        for (double x : nums) {
            double diff = x - mean;
            squareSum += diff * diff;
        }
        return Math.sqrt(squareSum / (nums.size() - 1));
    }

    /**
     * Returns the arithmetic mean of {@code nums}.
     *
     * @param nums values to average
     * @return the mean, or NaN for an empty list
     */
    public static double findMean(List<Double> nums) {
        double sum = 0;
        for (double x : nums) {
            sum += x;
        }
        return sum / nums.size();
    }

    /**
     * Z-score normalizes every numeric column of a table read by {@link #ReadFile}.
     *
     * <p>Row 0 (header) and column 0 (sample names) are skipped. Each remaining column becomes
     * {@code (x - mean) / sd} using the sample standard deviation. A column whose standard
     * deviation is zero or undefined (a constant column, or only one sample) becomes all zeros
     * instead of NaN, so it cannot poison the distance computations.
     *
     * @param Data table whose rows each have at least as many cells as row 0
     * @return column-major matrix {@code result.get(var).get(sample)}, where sample {@code s}
     *     corresponds to {@code Data.get(s + 1)}; empty when {@code Data} is empty
     * @throws NumberFormatException if a data cell is not a number
     */
    public static List<List<Double>> Nor(List<List<String>> Data) {
        List<List<Double>> norData = new ArrayList<>();
        if (Data.isEmpty()) {
            return norData;
        }
        int nRow = Data.size();
        int nCol = Data.get(0).size();
        for (int col = 1; col < nCol; col++) {
            // Parse the column once; the values feed both the statistics and the normalization.
            List<Double> values = new ArrayList<>(nRow - 1);
            for (int row = 1; row < nRow; row++) {
                values.add(Double.parseDouble(Data.get(row).get(col)));
            }
            double mean = findMean(values);
            double sd = findDeviation(values);
            List<Double> normalized = new ArrayList<>(values.size());
            for (double v : values) {
                // "sd > 0" is false for both 0 and NaN, the two degenerate cases.
                normalized.add(sd > 0 ? (v - mean) / sd : 0.0);
            }
            norData.add(normalized);
        }
        return norData;
    }

    /**
     * Draws {@code NumbCentroid} distinct random indices from {@code [0, sampleNumber)}.
     *
     * @param sampleNumber size of the index range
     * @param NumbCentroid number of distinct indices to draw
     * @return the indices, in the order they were drawn
     * @throws IllegalArgumentException if more distinct indices are requested than exist
     */
    public static Set<Integer> getRandomNumberUsingNextInt(int sampleNumber, int NumbCentroid) {
        if (NumbCentroid > sampleNumber) {
            throw new IllegalArgumentException("cannot draw " + NumbCentroid
                    + " distinct indices from " + sampleNumber + " samples");
        }
        Random random = new Random();
        Set<Integer> set = new LinkedHashSet<>();
        while (set.size() < NumbCentroid) {
            set.add(random.nextInt(sampleNumber));
        }
        return set;
    }

    /**
     * Builds initial centroids by copying the normalized values of the chosen samples.
     *
     * @param NorData column-major samples from {@link #Nor}
     * @param RandomN sample indices to use; their iteration order fixes the centroid order
     * @return column-major centroids {@code result.get(var).get(centroid)}
     */
    public static List<List<Double>> ChooseFirstCentroid2(List<List<Double>> NorData, Set<Integer> RandomN) {
        List<Integer> items = new ArrayList<>(RandomN);
        List<List<Double>> centroids = new ArrayList<>(NorData.size());
        for (List<Double> variable : NorData) {
            List<Double> row = new ArrayList<>(items.size());
            for (int sample : items) {
                row.add(variable.get(sample));
            }
            centroids.add(row);
        }
        return centroids;
    }

    /**
     * Computes the squared Euclidean distance between every sample and every centroid.
     *
     * @param NorData column-major samples {@code NorData.get(var).get(sample)}
     * @param Centroid column-major centroids {@code Centroid.get(var).get(centroid)}
     * @return row-major distances {@code result.get(sample).get(centroid)}, or an empty list
     *     when either input is empty; a NaN centroid coordinate yields a NaN distance
     */
    public static List<List<Double>> Cal_Distance(List<List<Double>> NorData, List<List<Double>> Centroid) {
        List<List<Double>> distance = new ArrayList<>();
        if (NorData.isEmpty() || Centroid.isEmpty()) {
            return distance;
        }
        int nVar = NorData.size();
        int nSample = NorData.get(0).size();
        int nCentroid = Centroid.get(0).size();
        for (int sample = 0; sample < nSample; sample++) {
            List<Double> row = new ArrayList<>(nCentroid);
            for (int c = 0; c < nCentroid; c++) {
                double sum = 0;
                for (int var = 0; var < nVar; var++) {
                    double diff = NorData.get(var).get(sample) - Centroid.get(var).get(c);
                    sum += diff * diff;
                }
                row.add(sum);
            }
            distance.add(row);
        }
        return distance;
    }

    /**
     * Assigns each sample to its nearest centroid and groups the sample indices by cluster.
     *
     * <p>Ties go to the lowest-numbered centroid. NaN distances (produced by an empty cluster's
     * centroid) are ignored, so an empty cluster cannot capture samples. A row with no comparable
     * distance falls back to cluster 0.
     *
     * @param Distance row-major distances from {@link #Cal_Distance}; every row has one entry
     *     per centroid
     * @return one list per centroid, in centroid order, holding its member sample indices in
     *     ascending order; empty when {@code Distance} is empty or has no centroids
     */
    public static List<List<Integer>> DecideCluster(List<List<Double>> Distance) {
        List<List<Integer>> clusters = new ArrayList<>();
        if (Distance.isEmpty() || Distance.get(0).isEmpty()) {
            return clusters;
        }
        int nCluster = Distance.get(0).size();
        for (int c = 0; c < nCluster; c++) {
            clusters.add(new ArrayList<Integer>());
        }
        for (int sample = 0; sample < Distance.size(); sample++) {
            List<Double> row = Distance.get(sample);
            double best = Double.POSITIVE_INFINITY;
            int bestIndex = 0;
            for (int c = 0; c < row.size(); c++) {
                double d = row.get(c);
                // Strict '<': keeps the first minimum and is false for NaN entries.
                if (d < best) {
                    best = d;
                    bestIndex = c;
                }
            }
            clusters.get(bestIndex).add(sample);
        }
        return clusters;
    }

    /**
     * Sums, over all samples, the distance to the nearest centroid, ignoring NaN distances.
     * With the squared distances from {@link #Cal_Distance} this is the within-cluster sum of
     * squares (WCSS).
     *
     * @param Distance row-major distances {@code Distance.get(sample).get(centroid)}
     * @return the summed nearest-centroid distance; never NaN
     */
    public static double SumDistance_toWCSS(List<List<Double>> Distance) {
        double sum = 0;
        for (List<Double> row : Distance) {
            double min = Double.POSITIVE_INFINITY;
            for (double d : row) {
                if (d < min) {
                    min = d;
                }
            }
            sum += min;
        }
        return sum;
    }

    /**
     * Recomputes the centroids as the per-variable mean of each cluster's members.
     *
     * @param NorData column-major samples from {@link #Nor}
     * @param Cluster member sample indices per cluster, as produced by {@link #DecideCluster}
     * @return column-major centroids {@code result.get(var).get(cluster)}; an empty cluster gets
     *     NaN coordinates, which {@link #DecideCluster} and {@link #SumDistance_toWCSS} ignore
     */
    public static List<List<Double>> Make_Centorid_usingMiddle(List<List<Double>> NorData, List<List<Integer>> Cluster) {
        List<List<Double>> centroids = new ArrayList<>(NorData.size());
        for (List<Double> variable : NorData) {
            List<Double> means = new ArrayList<>(Cluster.size());
            for (List<Integer> members : Cluster) {
                double sum = 0;
                for (int sample : members) {
                    sum += variable.get(sample);
                }
                // 0.0 / 0 is NaN for an empty cluster; intentional "no centroid" marker.
                means.add(sum / members.size());
            }
            centroids.add(means);
        }
        return centroids;
    }

    /**
     * Repeats the centroid-update / reassignment round until the clustering stops changing.
     *
     * <p>Despite the name, this runs as a loop, so a long convergence cannot overflow the stack.
     * It stops after {@value #DEFAULT_MAX_ITERATIONS} rounds and returns the latest clustering.
     *
     * @param Data_Nor column-major samples from {@link #Nor}
     * @param Cluster starting clustering
     * @return the converged (or last computed) clustering
     */
    public static List<List<Integer>> RecursedFunction(List<List<Double>> Data_Nor, List<List<Integer>> Cluster) {
        return RecursedFunction(Data_Nor, Cluster, DEFAULT_MAX_ITERATIONS);
    }

    /**
     * Same as {@link #RecursedFunction(List, List)}, with an explicit cap on the number of rounds.
     *
     * @param Data_Nor column-major samples from {@link #Nor}
     * @param Cluster starting clustering
     * @param maxIterations maximum number of update rounds; 0 or less returns {@code Cluster}
     * @return the converged (or last computed) clustering
     */
    public static List<List<Integer>> RecursedFunction(List<List<Double>> Data_Nor, List<List<Integer>> Cluster, int maxIterations) {
        List<List<Integer>> current = Cluster;
        for (int round = 0; round < maxIterations; round++) {
            List<List<Integer>> next = Combine_Distance_Cluster(Data_Nor, current);
            if (next.equals(current)) {
                return next;
            }
            current = next;
        }
        return current;
    }

    /**
     * Performs one k-means round: recompute centroids from {@code Cluster}, then reassign
     * every sample to its nearest centroid.
     *
     * @param Data_Nor column-major samples from {@link #Nor}
     * @param Cluster current clustering
     * @return the clustering after one round
     */
    public static List<List<Integer>> Combine_Distance_Cluster(List<List<Double>> Data_Nor, List<List<Integer>> Cluster) {
        List<List<Double>> centroids = Make_Centorid_usingMiddle(Data_Nor, Cluster);
        return DecideCluster(Cal_Distance(Data_Nor, centroids));
    }

    /**
     * Computes the WCSS of a clustering: centroids are the cluster means, and each sample
     * contributes its squared distance to the nearest centroid.
     *
     * @param Data_Nor column-major samples from {@link #Nor}
     * @param Cluster_Final clustering to score
     * @return the within-cluster sum of squares
     */
    public static double Calculate_WCSS(List<List<Double>> Data_Nor, List<List<Integer>> Cluster_Final) {
        List<List<Double>> centroids = Make_Centorid_usingMiddle(Data_Nor, Cluster_Final);
        return SumDistance_toWCSS(Cal_Distance(Data_Nor, centroids));
    }

    /**
     * Writes each cluster's sample names to {@code output}.
     *
     * <p>For each cluster the file gets a {@code ClusterN} line (1-based), one sample name per
     * line, and a blank line. Sample {@code s} is looked up as {@code Data.get(s + 1).get(0)}
     * because row 0 of {@code Data} is the header.
     *
     * @param output path of the file to (over)write
     * @param Data table as read by {@link #ReadFile}
     * @param Cluster member sample indices per cluster
     * @throws IOException if the file cannot be written
     */
    public static void WriteFile_FromArrayList(String output, List<List<String>> Data, List<List<Integer>> Cluster) throws IOException {
        // try-with-resources: the writer is flushed and closed even if a write fails.
        try (BufferedWriter out = new BufferedWriter(new FileWriter(output))) {
            for (int z = 0; z < Cluster.size(); z++) {
                out.write("Cluster" + (z + 1) + "\n");
                for (int sample : Cluster.get(z)) {
                    out.write(Data.get(sample + 1).get(0) + "\n");
                }
                out.write("\n");
            }
        }
    }

    /**
     * Command-line entry point: {@code java Kmeans <input.tsv> <minK> <maxK>}.
     *
     * <p>For each k in {@code [minK, maxK]} it runs {@code 300 * k} randomly seeded k-means
     * restarts and keeps the run with the smallest WCSS. It prints k, that WCSS, and
     * {@code WCSS + 2 * nVar * k} (the "AIC" column). The best clustering is written to
     * {@code Cluster<k>_<input file name>} in the same directory as the input file.
     *
     * @param args input path, minimum k, maximum k
     * @throws IOException if the input cannot be read or an output file cannot be written
     */
    public static void main(String[] args) throws IOException {
        if (args.length < 3) {
            System.err.println("Usage: java Kmeans <input.tsv> <minK> <maxK>");
            System.exit(1);
            return;
        }
        String input = args[0];
        System.out.println("k\tWCSS\tAIC");
        List<List<String>> data = ReadFile(input);
        List<List<Double>> dataNor = Nor(data);
        int nMinCen = Integer.parseInt(args[1]);
        int nMaxCen = Integer.parseInt(args[2]);
        if (dataNor.isEmpty()) {
            System.err.println("No numeric columns found in " + input);
            System.exit(1);
            return;
        }
        int nVar = dataNor.size();
        int sampleNumbers = dataNor.get(0).size();
        // k must be at least 1 and cannot exceed the number of distinct seed samples.
        if (nMinCen < 1 || nMaxCen > sampleNumbers) {
            System.err.println("k must lie in [1, " + sampleNumbers + "], got ["
                    + nMinCen + ", " + nMaxCen + "]");
            System.exit(1);
            return;
        }
        File inputFile = new File(input);

        for (int g = nMinCen; g <= nMaxCen; g++) {
            int nIterations = g * 300;
            double bestWcss = Double.POSITIVE_INFINITY;
            List<List<Integer>> bestClusters = null;
            for (int i = 0; i < nIterations; i++) {
                Set<Integer> seeds = getRandomNumberUsingNextInt(sampleNumbers, g);
                List<List<Double>> centroids = ChooseFirstCentroid2(dataNor, seeds);
                List<List<Integer>> clusters = DecideCluster(Cal_Distance(dataNor, centroids));
                List<List<Integer>> finalClusters = RecursedFunction(dataNor, clusters);
                double wcss = Calculate_WCSS(dataNor, finalClusters);
                // Strict '<' keeps the earliest of equally good runs.
                if (bestClusters == null || wcss < bestWcss) {
                    bestWcss = wcss;
                    bestClusters = finalClusters;
                }
            }
            System.out.println(g + "\t" + bestWcss + "\t" + (bestWcss + (nVar * g * 2)));
            // Prefix only the file name, so inputs given with a directory still resolve.
            File outFile = new File(inputFile.getParentFile(), "Cluster" + g + "_" + inputFile.getName());
            WriteFile_FromArrayList(outFile.getPath(), data, bestClusters);
        }
    }
}