root / trunk / src / main / java / edu / berkeley / compbio / jlibsvm / ExplicitSvmProblemImpl.java

Revision 113, 4.1 kB (checked in by soergel, 1 year ago)
Line 
1package edu.berkeley.compbio.jlibsvm;
2
3import com.davidsoergel.dsutils.collections.MappingIterator;
4import edu.berkeley.compbio.jlibsvm.scaler.NoopScalingModel;
5import edu.berkeley.compbio.jlibsvm.scaler.ScalingModel;
6import org.jetbrains.annotations.NotNull;
7
8import java.util.ArrayList;
9import java.util.Collections;
10import java.util.HashSet;
11import java.util.Iterator;
12import java.util.List;
13import java.util.Map;
14import java.util.Set;
15
16/**
17 * @author <a href="mailto:dev@davidsoergel.com">David Soergel</a>
18 * @version $Id$
19 */
20public abstract class ExplicitSvmProblemImpl<L extends Comparable, P, R extends SvmProblem<L, P, R>>
21                extends AbstractSvmProblem<L, P, R> implements ExplicitSvmProblem<L, P, R>
22        {
23// ------------------------------ FIELDS ------------------------------
24
25        public Map<P, L> examples;
26        public Map<P, Integer> exampleIds; // maintain a known order
27
28
29        public ScalingModel<P> scalingModel = new NoopScalingModel<P>();
30        //protected int numExamples = 0;
31
32        /**
33         * the unique set of targetvalues, in a defined order avoid populating for regression!  OK, regression should never
34         * call getLabels(), then.
35         */
36        protected List<L> labels = null;
37
38
39// --------------------------- CONSTRUCTORS ---------------------------
40
41        protected ExplicitSvmProblemImpl(Map<P, L> examples, @NotNull Map<P, Integer> exampleIds)
42                {
43                this.examples = examples;
44                this.exampleIds = exampleIds;
45                }
46
47        protected ExplicitSvmProblemImpl(Map<P, L> examples, @NotNull Map<P, Integer> exampleIds,
48                                         @NotNull ScalingModel<P> scalingModel)
49                {
50                this.examples = examples;
51                this.exampleIds = exampleIds;
52                this.scalingModel = scalingModel;
53                }
54
55        protected ExplicitSvmProblemImpl(Map<P, L> examples, @NotNull Map<P, Integer> exampleIds,
56                                         @NotNull ScalingModel<P> scalingModel, Set<P> heldOutPoints)
57                {
58                this.examples = examples;
59                this.exampleIds = exampleIds;
60                this.scalingModel = scalingModel;
61                this.heldOutPoints = heldOutPoints;
62                }
63
64// --------------------- GETTER / SETTER METHODS ---------------------
65
66        @NotNull
67        public Map<P, Integer> getExampleIds()
68                {
69                return exampleIds;
70                }
71
72        @NotNull
73        public Map<P, L> getExamples()
74                {
75                return examples;
76                }
77
78        public List<L> getLabels()
79                {
80                if (labels == null)
81                        {
82                        if (examples.isEmpty())
83                                {
84                                return null;
85                                }
86                        Set<L> uniq = new HashSet<L>(examples.values());
87                        labels = new ArrayList<L>(uniq);
88                        Collections.sort(labels);
89                        }
90                return labels;
91                }
92
93        @NotNull
94        public ScalingModel<P> getScalingModel()
95                {
96                return scalingModel;
97                }
98
99// ------------------------ INTERFACE METHODS ------------------------
100
101
102// --------------------- Interface ExplicitSvmProblem ---------------------
103
104        public Iterator<R> makeFolds(int numberOfFolds)
105                {
106//              Set<R> result = new HashSet<R>();
107
108                List<P> points = new ArrayList<P>(getExamples().keySet());
109
110                Collections.shuffle(points);
111
112                // PERF this is maybe overwrought, but ensures the best possible balance among folds (unlike examples.size() / numberOfFolds)
113
114                List<Set<P>> heldOutPointSets = new ArrayList<Set<P>>();
115                for (int i = 0; i < numberOfFolds; i++)
116                        {
117                        heldOutPointSets.add(new HashSet<P>());
118                        }
119
120                int f = 0;
121                for (P point : points)
122                        {
123                        heldOutPointSets.get(f).add(point);
124                        f++;
125                        f %= numberOfFolds;
126                        }
127
128
129                Iterator<R> foldIterator = new MappingIterator<Set<P>, R>(heldOutPointSets.iterator())
130                {
131                public R function(Set<P> p)
132                        {
133                        return makeFold(p);
134                        }
135                };
136                return foldIterator;
137
138/*
139                for (Set<P> heldOutPoints : heldOutPointSets)
140                        {
141                        result.add(makeFold(heldOutPoints));
142                        }
143
144                return result;*/
145                }
146
147// --------------------- Interface SvmProblem ---------------------
148
149        public int getId(P key)
150                {
151                return exampleIds.get(key);
152                }
153
154        public L getTargetValue(P point)
155                {
156                return examples.get(point);
157                }
158
159        public int getNumExamples()
160                {
161
162                return examples.size();
163                }
164/*
165        public R asR()
166                {
167                return (R) this;
168                }
169*/
170// -------------------------- OTHER METHODS --------------------------
171
172        protected Set<P> heldOutPoints = new HashSet<P>();
173        //protected Map<P, L> subtractionMap;
174
175        public Set<P> getHeldOutPoints()
176                {
177                return heldOutPoints;
178                }
179
180
181        protected abstract R makeFold(Set<P> heldOutPoints);
182        }
Note: See TracBrowser for help on using the browser.