Coverage details for edu.uci.ics.jung.algorithms.cluster.KMeansClusterer

Line | Hits | Source
1 /*
2  * Copyright (c) 2003, the JUNG Project and the Regents of the University
3  * of California
4  * All rights reserved.
5  *
6  * This software is open-source under the BSD license; see either
7  * "license.txt" or
8  * http://jung.sourceforge.net/license.txt for a description.
9  */
10 /*
11  * Created on Aug 9, 2004
12  *
13  */
14 package edu.uci.ics.jung.algorithms.cluster;
15  
16 import java.util.Arrays;
17 import java.util.Collection;
18 import java.util.HashMap;
19 import java.util.HashSet;
20 import java.util.Iterator;
21 import java.util.Map;
22 import java.util.Set;
23  
24 import cern.jet.random.engine.DRand;
25 import cern.jet.random.engine.RandomEngine;
26 import edu.uci.ics.jung.statistics.DiscreteDistribution;
27  
28  
29  
30 /**
31  * Groups Objects into a specified number of clusters, based on their
32  * proximity in d-dimensional space, using the k-means algorithm.
33  *
34  * @author Joshua O'Madadhain
35  */
36 public class KMeansClusterer
37 {
38     protected int max_iterations;
39     protected double convergence_threshold;
401    protected RandomEngine rand = new DRand();
41     
42     /**
43      * Creates an instance for which calls to <code>cluster</code> will terminate
44      * when either of the two following conditions is true:
45      * <ul>
46      * <li/>the number of iterations is > <code>max_iterations</code>
47      * <li/>none of the centroids has moved as much as <code>convergence_threshold</code>
48      * since the previous iteration
49      * </ul>
50      * @param max_iterations
51      * @param convergence_threshold
52      */
53     public KMeansClusterer(int max_iterations, double convergence_threshold)
541    {
551        if (max_iterations < 0)
560            throw new IllegalArgumentException("max iterations must be >= 0");
57         
581        if (convergence_threshold <= 0)
590            throw new IllegalArgumentException("convergence threshold " +
60                 "must be > 0");
61         
621        this.max_iterations = max_iterations;
631        this.convergence_threshold = convergence_threshold;
641    }
65     
66     /**
67      * Returns a <code>Collection</code> of clusters, where each cluster is
68      * represented as a <code>Map</code> of <code>Objects</code> to locations
69      * in d-dimensional space.
70      * @param object_locations a map of the Objects to cluster, to
71      * <code>double</code> arrays that specify their locations in d-dimensional space.
72      * @param num_clusters the number of clusters to create
73      * @throws NotEnoughClustersException
74      */
75     public Collection cluster(Map object_locations, int num_clusters)
76     {
773        if (num_clusters < 2 || num_clusters > object_locations.size())
781            throw new IllegalArgumentException("number of clusters " +
79                 "must be >= 2 and <= number of objects (" +
80                 object_locations.size() + ")");
81         
822        if (object_locations == null || object_locations.isEmpty())
830            throw new IllegalArgumentException("'objects' must be non-empty");
84  
852        Set centroids = new HashSet();
862        Object[] obj_array = object_locations.keySet().toArray();
872        Set tried = new HashSet();
88         
89         // create the specified number of clusters
9012        while (centroids.size() < num_clusters && tried.size() < object_locations.size())
91         {
9210            Object o = obj_array[(int)(rand.nextDouble() * obj_array.length)];
9310            tried.add(o);
9410            double[] mean_value = (double[])object_locations.get(o);
9510            boolean duplicate = false;
9610            for (Iterator iter = centroids.iterator(); iter.hasNext(); )
97             {
988                double[] cur = (double[])iter.next();
998                if (Arrays.equals(mean_value, cur))
1006                    duplicate = true;
101             }
10210            if (!duplicate)
1034                centroids.add(mean_value);
104         }
105         
1062        if (tried.size() >= object_locations.size())
1071            throw new NotEnoughClustersException();
108         
109         // put items in their initial clusters
1101        Map clusterMap = assignToClusters(object_locations, centroids);
111         
112         // keep reconstituting clusters until either
113         // (a) membership is stable, or
114         // (b) number of iterations passes max_iterations, or
115         // (c) max movement of any centroid is <= convergence_threshold
1161        int iterations = 0;
1171        double max_movement = Double.POSITIVE_INFINITY;
1183        while (iterations++ < max_iterations && max_movement > convergence_threshold)
119         {
1202            max_movement = 0;
1212            Set new_centroids = new HashSet();
122             // calculate new mean for each cluster
1232            for (Iterator iter = clusterMap.keySet().iterator(); iter.hasNext(); )
124             {
1254                double[] centroid = (double[])iter.next();
1264                Map elements = (Map)clusterMap.get(centroid);
1274                double[][] locations = new double[elements.size()][];
1284                int i = 0;
1294                for (Iterator e_iter = elements.keySet().iterator(); e_iter.hasNext(); )
13010                    locations[i++] = (double[])object_locations.get(e_iter.next());
131                 
1324                double[] mean = DiscreteDistribution.mean(locations);
1334                max_movement = Math.max(max_movement,
134                     Math.sqrt(DiscreteDistribution.squaredError(centroid, mean)));
1354                new_centroids.add(mean);
136             }
137             
138             // TODO: check membership of clusters: have they changed?
139  
140             // regenerate cluster membership based on means
1412            clusterMap = assignToClusters(object_locations, new_centroids);
142         }
1431        return (Collection)clusterMap.values();
144     }
145  
146     /**
147      * Assigns each object to the cluster whose centroid is closest to the
148      * object.
149      * @param object_locations a map of objects to locations
150      * @param centroids the centroids of the clusters to be formed
151      * @return a map of objects to assigned clusters
152      */
153     protected Map assignToClusters(Map object_locations, Set centroids)
154     {
1553        Map clusterMap = new HashMap();
1563        for (Iterator c_iter = centroids.iterator(); c_iter.hasNext(); )
1576            clusterMap.put(c_iter.next(), new HashMap());
158         
1593        for (Iterator o_iter = object_locations.keySet().iterator(); o_iter.hasNext(); )
160         {
16115            Object o = o_iter.next();
16215            double[] location = (double[])object_locations.get(o);
163  
164             // find the cluster with the closest centroid
16515            Iterator c_iter = centroids.iterator();
16615            double[] closest = (double[])c_iter.next();
16715            double distance = DiscreteDistribution.squaredError(location, closest);
168             
16930            while (c_iter.hasNext())
170             {
17115                double[] centroid = (double[])c_iter.next();
17215                double dist_cur = DiscreteDistribution.squaredError(location, centroid);
17315                if (dist_cur < distance)
174                 {
1758                    distance = dist_cur;
1768                    closest = centroid;
177                 }
178             }
17915            Map elements = (Map)clusterMap.get(closest);
18015            elements.put(o, location);
181         }
182         
1833        return clusterMap;
184     }
185     
186     public void setSeed(int random_seed)
187     {
1880        this.rand = new DRand(random_seed);
1890    }
190     
191     /**
192      * An exception that indicates that the specified data points cannot be
193      * clustered into the number of clusters requested by the user.
194      * This will happen if and only if there are fewer distinct points than
195      * requested clusters. (If there are fewer total data points than
196      * requested clusters, <code>IllegalArgumentException</code> will be thrown.)
197      *
198      * @author Joshua O'Madadhain
199      */
200     public static class NotEnoughClustersException extends RuntimeException
201     {
202         public String getMessage()
203         {
204             return "Not enough distinct points in the input data set to form " +
205                     "the requested number of clusters";
206         }
207     }
208 }

this report was generated by version 1.0.5 of jcoverage.
visit www.jcoverage.com for updates.

copyright © 2003, jcoverage ltd. all rights reserved.
Java is a trademark of Sun Microsystems, Inc. in the United States and other countries.