| 1 | package edu.berkeley.compbio.jlibsvm; |
|---|
| 2 | |
|---|
| 3 | import com.davidsoergel.dsutils.collections.MappingIterator; |
|---|
| 4 | import edu.berkeley.compbio.jlibsvm.scaler.NoopScalingModel; |
|---|
| 5 | import edu.berkeley.compbio.jlibsvm.scaler.ScalingModel; |
|---|
| 6 | import org.jetbrains.annotations.NotNull; |
|---|
| 7 | |
|---|
| 8 | import java.util.ArrayList; |
|---|
| 9 | import java.util.Collections; |
|---|
| 10 | import java.util.HashSet; |
|---|
| 11 | import java.util.Iterator; |
|---|
| 12 | import java.util.List; |
|---|
| 13 | import java.util.Map; |
|---|
| 14 | import java.util.Set; |
|---|
| 15 | |
|---|
| 16 | /** |
|---|
| 17 | * @author <a href="mailto:dev@davidsoergel.com">David Soergel</a> |
|---|
| 18 | * @version $Id$ |
|---|
| 19 | */ |
|---|
| 20 | public abstract class ExplicitSvmProblemImpl<L extends Comparable, P, R extends SvmProblem<L, P, R>> |
|---|
| 21 | extends AbstractSvmProblem<L, P, R> implements ExplicitSvmProblem<L, P, R> |
|---|
| 22 | { |
|---|
| 23 | // ------------------------------ FIELDS ------------------------------ |
|---|
| 24 | |
|---|
| 25 | public Map<P, L> examples; |
|---|
| 26 | public Map<P, Integer> exampleIds; // maintain a known order |
|---|
| 27 | |
|---|
| 28 | |
|---|
| 29 | public ScalingModel<P> scalingModel = new NoopScalingModel<P>(); |
|---|
| 30 | //protected int numExamples = 0; |
|---|
| 31 | |
|---|
| 32 | /** |
|---|
| 33 | * the unique set of targetvalues, in a defined order avoid populating for regression! OK, regression should never |
|---|
| 34 | * call getLabels(), then. |
|---|
| 35 | */ |
|---|
| 36 | protected List<L> labels = null; |
|---|
| 37 | |
|---|
| 38 | |
|---|
| 39 | // --------------------------- CONSTRUCTORS --------------------------- |
|---|
| 40 | |
|---|
| 41 | protected ExplicitSvmProblemImpl(Map<P, L> examples, @NotNull Map<P, Integer> exampleIds) |
|---|
| 42 | { |
|---|
| 43 | this.examples = examples; |
|---|
| 44 | this.exampleIds = exampleIds; |
|---|
| 45 | } |
|---|
| 46 | |
|---|
| 47 | protected ExplicitSvmProblemImpl(Map<P, L> examples, @NotNull Map<P, Integer> exampleIds, |
|---|
| 48 | @NotNull ScalingModel<P> scalingModel) |
|---|
| 49 | { |
|---|
| 50 | this.examples = examples; |
|---|
| 51 | this.exampleIds = exampleIds; |
|---|
| 52 | this.scalingModel = scalingModel; |
|---|
| 53 | } |
|---|
| 54 | |
|---|
| 55 | protected ExplicitSvmProblemImpl(Map<P, L> examples, @NotNull Map<P, Integer> exampleIds, |
|---|
| 56 | @NotNull ScalingModel<P> scalingModel, Set<P> heldOutPoints) |
|---|
| 57 | { |
|---|
| 58 | this.examples = examples; |
|---|
| 59 | this.exampleIds = exampleIds; |
|---|
| 60 | this.scalingModel = scalingModel; |
|---|
| 61 | this.heldOutPoints = heldOutPoints; |
|---|
| 62 | } |
|---|
| 63 | |
|---|
| 64 | // --------------------- GETTER / SETTER METHODS --------------------- |
|---|
| 65 | |
|---|
| 66 | @NotNull |
|---|
| 67 | public Map<P, Integer> getExampleIds() |
|---|
| 68 | { |
|---|
| 69 | return exampleIds; |
|---|
| 70 | } |
|---|
| 71 | |
|---|
| 72 | @NotNull |
|---|
| 73 | public Map<P, L> getExamples() |
|---|
| 74 | { |
|---|
| 75 | return examples; |
|---|
| 76 | } |
|---|
| 77 | |
|---|
| 78 | public List<L> getLabels() |
|---|
| 79 | { |
|---|
| 80 | if (labels == null) |
|---|
| 81 | { |
|---|
| 82 | if (examples.isEmpty()) |
|---|
| 83 | { |
|---|
| 84 | return null; |
|---|
| 85 | } |
|---|
| 86 | Set<L> uniq = new HashSet<L>(examples.values()); |
|---|
| 87 | labels = new ArrayList<L>(uniq); |
|---|
| 88 | Collections.sort(labels); |
|---|
| 89 | } |
|---|
| 90 | return labels; |
|---|
| 91 | } |
|---|
| 92 | |
|---|
| 93 | @NotNull |
|---|
| 94 | public ScalingModel<P> getScalingModel() |
|---|
| 95 | { |
|---|
| 96 | return scalingModel; |
|---|
| 97 | } |
|---|
| 98 | |
|---|
| 99 | // ------------------------ INTERFACE METHODS ------------------------ |
|---|
| 100 | |
|---|
| 101 | |
|---|
| 102 | // --------------------- Interface ExplicitSvmProblem --------------------- |
|---|
| 103 | |
|---|
| 104 | public Iterator<R> makeFolds(int numberOfFolds) |
|---|
| 105 | { |
|---|
| 106 | // Set<R> result = new HashSet<R>(); |
|---|
| 107 | |
|---|
| 108 | List<P> points = new ArrayList<P>(getExamples().keySet()); |
|---|
| 109 | |
|---|
| 110 | Collections.shuffle(points); |
|---|
| 111 | |
|---|
| 112 | // PERF this is maybe overwrought, but ensures the best possible balance among folds (unlike examples.size() / numberOfFolds) |
|---|
| 113 | |
|---|
| 114 | List<Set<P>> heldOutPointSets = new ArrayList<Set<P>>(); |
|---|
| 115 | for (int i = 0; i < numberOfFolds; i++) |
|---|
| 116 | { |
|---|
| 117 | heldOutPointSets.add(new HashSet<P>()); |
|---|
| 118 | } |
|---|
| 119 | |
|---|
| 120 | int f = 0; |
|---|
| 121 | for (P point : points) |
|---|
| 122 | { |
|---|
| 123 | heldOutPointSets.get(f).add(point); |
|---|
| 124 | f++; |
|---|
| 125 | f %= numberOfFolds; |
|---|
| 126 | } |
|---|
| 127 | |
|---|
| 128 | |
|---|
| 129 | Iterator<R> foldIterator = new MappingIterator<Set<P>, R>(heldOutPointSets.iterator()) |
|---|
| 130 | { |
|---|
| 131 | public R function(Set<P> p) |
|---|
| 132 | { |
|---|
| 133 | return makeFold(p); |
|---|
| 134 | } |
|---|
| 135 | }; |
|---|
| 136 | return foldIterator; |
|---|
| 137 | |
|---|
| 138 | /* |
|---|
| 139 | for (Set<P> heldOutPoints : heldOutPointSets) |
|---|
| 140 | { |
|---|
| 141 | result.add(makeFold(heldOutPoints)); |
|---|
| 142 | } |
|---|
| 143 | |
|---|
| 144 | return result;*/ |
|---|
| 145 | } |
|---|
| 146 | |
|---|
| 147 | // --------------------- Interface SvmProblem --------------------- |
|---|
| 148 | |
|---|
| 149 | public int getId(P key) |
|---|
| 150 | { |
|---|
| 151 | return exampleIds.get(key); |
|---|
| 152 | } |
|---|
| 153 | |
|---|
| 154 | public L getTargetValue(P point) |
|---|
| 155 | { |
|---|
| 156 | return examples.get(point); |
|---|
| 157 | } |
|---|
| 158 | |
|---|
| 159 | public int getNumExamples() |
|---|
| 160 | { |
|---|
| 161 | |
|---|
| 162 | return examples.size(); |
|---|
| 163 | } |
|---|
| 164 | /* |
|---|
| 165 | public R asR() |
|---|
| 166 | { |
|---|
| 167 | return (R) this; |
|---|
| 168 | } |
|---|
| 169 | */ |
|---|
| 170 | // -------------------------- OTHER METHODS -------------------------- |
|---|
| 171 | |
|---|
| 172 | protected Set<P> heldOutPoints = new HashSet<P>(); |
|---|
| 173 | //protected Map<P, L> subtractionMap; |
|---|
| 174 | |
|---|
| 175 | public Set<P> getHeldOutPoints() |
|---|
| 176 | { |
|---|
| 177 | return heldOutPoints; |
|---|
| 178 | } |
|---|
| 179 | |
|---|
| 180 | |
|---|
| 181 | protected abstract R makeFold(Set<P> heldOutPoints); |
|---|
| 182 | } |
|---|