// Written by Joyful Bioinfo.
import java.io.*; // Import the File class
import java.util.*; // Import the Scanner class to read text files
/**
 * Command-line k-means clustering for a tab-separated table.
 *
 * <p>Input layout: the first row is a header and the first column holds the sample
 * names; every other cell must be parseable by {@link Double#parseDouble(String)}.
 *
 * <p>Internal layouts used throughout this class:
 * <ul>
 *   <li>normalized data: {@code data.get(variable).get(sample)}</li>
 *   <li>centroids: {@code centroid.get(variable).get(cluster)}</li>
 *   <li>distances: {@code distance.get(sample).get(cluster)} (squared Euclidean)</li>
 *   <li>clusters: {@code cluster.get(k)} is the list of 0-based sample indices in cluster k</li>
 * </ul>
 */
public class Kmeans {

    /**
     * Upper bound on centroid/assignment refinement passes for one restart. It replaces
     * the unbounded recursion of the earlier implementation so that a pathological input
     * can neither overflow the stack nor loop forever.
     */
    private static final int MAX_REFINEMENT_PASSES = 10000;

    /**
     * Reads a tab-separated file into a list of rows, each row being its list of fields.
     * Blank lines are skipped because they carry no fields and would break the column
     * indexing done by {@link #Nor(List)}.
     *
     * @param input path of the file to read
     * @return one list of fields per non-blank line, in file order
     * @throws IOException if the file cannot be opened or read
     */
    public static List<List<String>> ReadFile(String input) throws IOException {
        List<List<String>> Data = new ArrayList<>();
        // try-with-resources closes the reader even when readLine() throws.
        try (BufferedReader bufReader = new BufferedReader(new FileReader(input))) {
            String line;
            while ((line = bufReader.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue;
                }
                List<String> fields = new ArrayList<>(Arrays.asList(line.split("\t")));
                Data.add(fields);
            }
        }
        return Data;
    }

    /**
     * Computes the sample standard deviation (divisor {@code n - 1}).
     *
     * @param nums values to summarize
     * @return the sample standard deviation; NaN when fewer than two values are given
     */
    public static double findDeviation(List<Double> nums) {
        double mean = findMean(nums);
        double squareSum = 0;
        for (double value : nums) {
            squareSum += Math.pow(value - mean, 2);
        }
        return Math.sqrt(squareSum / (nums.size() - 1));
    }

    /**
     * Computes the arithmetic mean.
     *
     * @param nums values to average
     * @return the mean; NaN for an empty list
     */
    public static double findMean(List<Double> nums) {
        double sum = 0;
        for (double value : nums) {
            sum += value;
        }
        return sum / nums.size();
    }

    /**
     * Z-score normalizes every variable column of the raw table.
     *
     * <p>Row 0 (header) and column 0 (sample names) are skipped. A column whose standard
     * deviation is zero (all values equal) or undefined (a single sample) is mapped to
     * all zeros instead of NaN, so that it simply contributes nothing to the distances.
     *
     * @param Data raw table as returned by {@link #ReadFile(String)}
     * @return normalized values indexed as {@code get(variable).get(sample)}
     * @throws NumberFormatException if a data cell is not a valid number
     */
    public static List<List<Double>> Nor(List<List<String>> Data) {
        int nRow = Data.size();
        int nCol = Data.get(0).size();
        List<List<Double>> NorData = new ArrayList<>();
        for (int col = 1; col < nCol; col++) {
            // Parse each cell once and reuse the parsed values for the statistics
            // and for the scaling below.
            List<Double> raw = new ArrayList<>();
            for (int row = 1; row < nRow; row++) {
                raw.add(Double.parseDouble(Data.get(row).get(col)));
            }
            double mean = findMean(raw);
            double sd = findDeviation(raw);
            boolean degenerate = raw.size() < 2 || sd == 0.0;

            List<Double> scaled = new ArrayList<>();
            for (double value : raw) {
                scaled.add(degenerate ? 0.0 : (value - mean) / sd);
            }
            NorData.add(scaled);
        }
        return NorData;
    }

    /**
     * Draws distinct random sample indices to be used as initial centroids.
     *
     * @param sampleNumber number of samples; indices are drawn from {@code [0, sampleNumber)}
     * @param NumbCentroid how many distinct indices to draw
     * @return the drawn indices in draw order
     * @throws IllegalArgumentException if more distinct indices are requested than exist
     */
    public static Set<Integer> getRandomNumberUsingNextInt(int sampleNumber, int NumbCentroid) {
        if (NumbCentroid > sampleNumber) {
            // Without this check the loop below could never terminate.
            throw new IllegalArgumentException(
                    "Cannot pick " + NumbCentroid + " distinct centroids from "
                            + sampleNumber + " samples");
        }
        Random random = new Random();
        Set<Integer> set = new LinkedHashSet<>();
        while (set.size() < NumbCentroid) {
            set.add(random.nextInt(sampleNumber));
        }
        return set;
    }

    /**
     * Builds the initial centroids by copying the coordinates of the chosen samples.
     *
     * @param NorData normalized data, {@code get(variable).get(sample)}
     * @param RandomN sample indices to use as centroids
     * @return centroids indexed as {@code get(variable).get(cluster)}
     */
    public static List<List<Double>> ChooseFirstCentroid2(List<List<Double>> NorData, Set<Integer> RandomN) {
        List<Integer> items = new ArrayList<>(RandomN);
        List<List<Double>> Centroid_array = new ArrayList<>();
        for (List<Double> variable : NorData) {
            List<Double> coords = new ArrayList<>();
            for (int sampleIndex : items) {
                coords.add(variable.get(sampleIndex));
            }
            Centroid_array.add(coords);
        }
        return Centroid_array;
    }

    /**
     * Computes the squared Euclidean distance from every sample to every centroid.
     *
     * @param NorData  normalized data, {@code get(variable).get(sample)}
     * @param Centroid centroids, {@code get(variable).get(cluster)}
     * @return distances indexed as {@code get(sample).get(cluster)}
     */
    public static List<List<Double>> Cal_Distance(List<List<Double>> NorData, List<List<Double>> Centroid) {
        List<List<Double>> Distance = new ArrayList<>();
        int nCentroid = Centroid.get(0).size();
        int nVar = NorData.size();
        int nSample = NorData.get(0).size();
        for (int sample = 0; sample < nSample; sample++) {
            List<Double> row = new ArrayList<>();
            for (int cent = 0; cent < nCentroid; cent++) {
                double distanceSum = 0;
                for (int var = 0; var < nVar; var++) {
                    double diff = NorData.get(var).get(sample) - Centroid.get(var).get(cent);
                    distanceSum += Math.pow(diff, 2);
                }
                row.add(distanceSum);
            }
            Distance.add(row);
        }
        return Distance;
    }

    /**
     * Returns the index of the smallest non-NaN entry of {@code row}; ties go to the
     * lowest index. NaN entries (produced by the centroid of an empty cluster) are
     * ignored because every ordering comparison against NaN is false. Returns 0 when
     * the row holds no comparable value.
     */
    private static int indexOfNearest(List<Double> row) {
        int best = 0;
        double bestValue = Double.NaN;
        for (int j = 0; j < row.size(); j++) {
            double value = row.get(j);
            if (Double.isNaN(value)) {
                continue;
            }
            if (Double.isNaN(bestValue) || value < bestValue) {
                bestValue = value;
                best = j;
            }
        }
        return best;
    }

    /**
     * Assigns every sample to its nearest centroid.
     *
     * @param Distance distances, {@code get(sample).get(cluster)}
     * @return one list of 0-based sample indices per cluster (possibly empty), in
     *         cluster order; an empty list when there are no samples
     */
    public static List<List<Integer>> DecideCluster(List<List<Double>> Distance) {
        List<List<Integer>> Cluster_Re = new ArrayList<>();
        if (Distance.isEmpty()) {
            return Cluster_Re;
        }
        // The cluster count comes from the first row so that a single sample works too.
        int nCluster = Distance.get(0).size();
        for (int k = 0; k < nCluster; k++) {
            List<Integer> members = new ArrayList<>();
            Cluster_Re.add(members);
        }
        for (int sample = 0; sample < Distance.size(); sample++) {
            int nearest = indexOfNearest(Distance.get(sample));
            if (nearest < nCluster) {
                Cluster_Re.get(nearest).add(sample);
            }
        }
        return Cluster_Re;
    }

    /**
     * Sums, over all samples, the distance to the nearest centroid (the within-cluster
     * sum of squares when {@code Distance} holds squared distances).
     *
     * @param Distance distances, {@code get(sample).get(cluster)}
     * @return the sum of each row's smallest non-NaN entry
     */
    public static double SumDistance_toWCSS(List<List<Double>> Distance) {
        double Sum = 0;
        for (List<Double> row : Distance) {
            Sum += row.get(indexOfNearest(row));
        }
        return Sum;
    }

    /**
     * Recomputes each centroid as the mean of its cluster members.
     *
     * <p>An empty cluster has no mean; its centroid coordinates are NaN, which
     * {@link #DecideCluster(List)} and {@link #SumDistance_toWCSS(List)} ignore.
     *
     * @param NorData normalized data, {@code get(variable).get(sample)}
     * @param Cluster sample indices per cluster
     * @return centroids indexed as {@code get(variable).get(cluster)}
     */
    public static List<List<Double>> Make_Centorid_usingMiddle(List<List<Double>> NorData, List<List<Integer>> Cluster) {
        List<List<Double>> Centroid = new ArrayList<>();
        for (List<Double> variable : NorData) {
            List<Double> means = new ArrayList<>();
            for (List<Integer> members : Cluster) {
                double sum = 0;
                for (int sampleIndex : members) {
                    sum += variable.get(sampleIndex);
                }
                means.add(members.isEmpty() ? Double.NaN : sum / members.size());
            }
            Centroid.add(means);
        }
        return Centroid;
    }

    /**
     * Repeats the update step until the cluster assignment stops changing.
     *
     * <p>Implemented as a loop (the name is kept for compatibility) and bounded by
     * {@link #MAX_REFINEMENT_PASSES}; if the bound is hit, the latest assignment is
     * returned as is.
     *
     * @param Data_Nor normalized data, {@code get(variable).get(sample)}
     * @param Cluster  starting assignment
     * @return the converged (or last computed) assignment
     */
    public static List<List<Integer>> RecursedFunction(List<List<Double>> Data_Nor, List<List<Integer>> Cluster) {
        List<List<Integer>> current = Cluster;
        for (int pass = 0; pass < MAX_REFINEMENT_PASSES; pass++) {
            List<List<Integer>> next = Combine_Distance_Cluster(Data_Nor, current);
            if (next.equals(current)) {
                return next;
            }
            current = next;
        }
        return current;
    }

    /**
     * Performs one k-means update: centroids from the given assignment, then a fresh
     * assignment of every sample to its nearest centroid.
     *
     * @param Data_Nor normalized data, {@code get(variable).get(sample)}
     * @param Cluster  current assignment
     * @return the new assignment
     */
    public static List<List<Integer>> Combine_Distance_Cluster(List<List<Double>> Data_Nor, List<List<Integer>> Cluster) {
        List<List<Double>> Centroid = Make_Centorid_usingMiddle(Data_Nor, Cluster);
        List<List<Double>> Distance = Cal_Distance(Data_Nor, Centroid);
        return DecideCluster(Distance);
    }

    /**
     * Computes the within-cluster sum of squares of an assignment.
     *
     * @param Data_Nor      normalized data, {@code get(variable).get(sample)}
     * @param Cluster_Final assignment to evaluate
     * @return sum of squared distances from each sample to its nearest cluster mean
     */
    public static double Calculate_WCSS(List<List<Double>> Data_Nor, List<List<Integer>> Cluster_Final) {
        List<List<Double>> Centroid = Make_Centorid_usingMiddle(Data_Nor, Cluster_Final);
        List<List<Double>> Distance = Cal_Distance(Data_Nor, Centroid);
        return SumDistance_toWCSS(Distance);
    }

    /**
     * Writes the sample names of every cluster to a text file, one block per cluster
     * introduced by a {@code ClusterN} line and followed by a blank line.
     *
     * @param output  path of the file to create or overwrite
     * @param Data    raw table; row 0 is the header, so sample {@code i} is row {@code i + 1}
     * @param Cluster 0-based sample indices per cluster
     * @throws IOException if the file cannot be written
     */
    public static void WriteFile_FromArrayList(String output, List<List<String>> Data, List<List<Integer>> Cluster) throws IOException {
        // try-with-resources flushes and closes the writer even when a write fails.
        try (BufferedWriter outfile = new BufferedWriter(new FileWriter(output))) {
            for (int z = 0; z < Cluster.size(); z++) {
                outfile.write("Cluster" + (z + 1) + "\n");
                for (int sampleIndex : Cluster.get(z)) {
                    // +1 skips the header row of the raw table.
                    String SampleName = Data.get(sampleIndex + 1).get(0);
                    outfile.write(SampleName + "\n");
                }
                outfile.write("\n");
            }
        }
    }

    /**
     * Entry point. For every k in {@code [minK, maxK]} runs {@code 300 * k} random
     * restarts, keeps the clustering with the lowest WCSS, prints
     * {@code k <TAB> WCSS <TAB> WCSS + 2 * variables * k} and writes the clustering to
     * {@code Cluster<k>_<input file name>} in the current directory.
     *
     * @param args {@code <input.tsv> <minK> <maxK>}
     * @throws IOException              if the input cannot be read or an output cannot be written
     * @throws IllegalArgumentException if the arguments or the table are unusable
     */
    public static void main(String args[]) throws IOException {
        if (args.length < 3) {
            throw new IllegalArgumentException("Usage: java Kmeans <input.tsv> <minK> <maxK>");
        }
        String input = args[0];
        int nMin_Cen = Integer.parseInt(args[1]);
        int nMax_Cen = Integer.parseInt(args[2]);
        if (nMin_Cen < 1) {
            throw new IllegalArgumentException("minK must be at least 1, got " + nMin_Cen);
        }

        System.out.println("k\tWCSS\tAIC");
        List<List<String>> Data = ReadFile(input);
        if (Data.size() < 2) {
            throw new IllegalArgumentException("Input needs a header row and at least one sample row");
        }
        List<List<Double>> Data_Nor = Nor(Data);
        if (Data_Nor.isEmpty()) {
            throw new IllegalArgumentException("Input needs at least one variable column");
        }
        int nVar = Data_Nor.size();
        int SampleNumbers = Data_Nor.get(0).size();
        if (nMax_Cen > SampleNumbers) {
            throw new IllegalArgumentException(
                    "maxK (" + nMax_Cen + ") exceeds the number of samples (" + SampleNumbers + ")");
        }
        // Use only the file name so that an input given with a directory still yields a
        // valid output path.
        String inputName = new File(input).getName();

        for (int g = nMin_Cen; g <= nMax_Cen; g++) {
            int NumbCentroid = g;
            int nIterations = NumbCentroid * 300;
            List<List<Integer>> bestCluster = null;
            double bestWcss = Double.NaN;

            for (int i = 0; i < nIterations; i++) {
                Set<Integer> RandomNumber = getRandomNumberUsingNextInt(SampleNumbers, NumbCentroid);
                List<List<Double>> Centroid = ChooseFirstCentroid2(Data_Nor, RandomNumber);
                List<List<Double>> Distance = Cal_Distance(Data_Nor, Centroid);
                List<List<Integer>> Cluster = DecideCluster(Distance);
                List<List<Integer>> Cluster_Final = RecursedFunction(Data_Nor, Cluster);
                double WCSS = Calculate_WCSS(Data_Nor, Cluster_Final);

                // Keep only the best restart; the first restart with the minimum wins, and
                // a NaN score never displaces or blocks a real one.
                boolean better = bestCluster == null
                        || WCSS < bestWcss
                        || (Double.isNaN(bestWcss) && !Double.isNaN(WCSS));
                if (better) {
                    bestWcss = WCSS;
                    bestCluster = Cluster_Final;
                }
            }

            System.out.println(g + "\t" + bestWcss + "\t" + (bestWcss + (nVar * NumbCentroid * 2)));
            WriteFile_FromArrayList("Cluster" + g + "_" + inputName, Data, bestCluster);
        }
    }
}