/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.hbase.mapreduce.hadoopbackport;

import java.io.IOException;
import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.RawComparator;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapreduce.InputFormat;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.hadoop.mapreduce.TaskAttemptID;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.util.ReflectionUtils;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
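/**
 * Utility for collecting samples and writing a partition file for
 * {@link TotalOrderPartitioner}.
 *
 * <p>A typical use, sketched below with hypothetical job and path names,
 * samples the job's input, writes the partition file, and points the
 * {@link TotalOrderPartitioner} at it before the job is submitted:
 *
 * <pre>{@code
 * Job job = new Job(conf, "total-order-sort");        // hypothetical job
 * job.setNumReduceTasks(8);
 * job.setPartitionerClass(TotalOrderPartitioner.class);
 * TotalOrderPartitioner.setPartitionFile(job.getConfiguration(),
 *     new Path("/tmp/partitions.lst"));               // hypothetical path
 * InputSampler.Sampler<Text, Text> sampler =
 *     new InputSampler.RandomSampler<Text, Text>(0.1, 10000, 10);
 * InputSampler.writePartitionFile(job, sampler);
 * }</pre>
 */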
public class InputSampler<K,V> extends Configured implements Tool {

  private static final Log LOG = LogFactory.getLog(InputSampler.class);

  static int printUsage() {
    System.out.println("sampler -r <reduces>\n" +
        " [-inFormat <input format class>]\n" +
        " [-keyClass <map input & output key class>]\n" +
        " [-splitRandom <double pcnt> <numSamples> <maxsplits> | " +
        " // Sample from random splits at random (general)\n" +
        " -splitSample <numSamples> <maxsplits> | " +
        " // Sample from first records in splits (random data)\n" +
        " -splitInterval <double pcnt> <maxsplits>]" +
        " // Sample from splits at intervals (sorted data)");
    System.out.println("Default sampler: -splitRandom 0.1 10000 10");
    ToolRunner.printGenericCommandUsage(System.out);
    return -1;
  }

  public InputSampler(Configuration conf) {
    setConf(conf);
  }
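  /**
   * Interface to sample using an {@link org.apache.hadoop.mapreduce.InputFormat}.
   */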
  public interface Sampler<K,V> {
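    /**
     * For a given job, collect and return a subset of the keys from the
     * input data.
     */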
    K[] getSample(InputFormat<K,V> inf, Job job)
        throws IOException, InterruptedException;
  }
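  /**
   * Samples the first n records from s splits.
   * Inexpensive way to sample random data.
   */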
  public static class SplitSampler<K,V> implements Sampler<K,V> {

    private final int numSamples;
    private final int maxSplitsSampled;
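    /**
     * Create a SplitSampler sampling <em>all</em> splits.
     * Takes the first numSamples / numSplits records from each split.
     * @param numSamples Total number of samples to obtain from all selected
     *                   splits.
     */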
    public SplitSampler(int numSamples) {
      this(numSamples, Integer.MAX_VALUE);
    }
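    /**
     * Create a new SplitSampler.
     * @param numSamples Total number of samples to obtain from all selected
     *                   splits.
     * @param maxSplitsSampled The maximum number of splits to examine.
     */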
    public SplitSampler(int numSamples, int maxSplitsSampled) {
      this.numSamples = numSamples;
      this.maxSplitsSampled = maxSplitsSampled;
    }
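    /**
     * From each split sampled, take the first numSamples / numSplits records.
     */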
    @SuppressWarnings("unchecked")
    @Override
    public K[] getSample(InputFormat<K,V> inf, Job job)
        throws IOException, InterruptedException {
      List<InputSplit> splits = inf.getSplits(job);
      ArrayList<K> samples = new ArrayList<K>(numSamples);
      int splitsToSample = Math.min(maxSplitsSampled, splits.size());
      int samplesPerSplit = numSamples / splitsToSample;
      long records = 0;
      for (int i = 0; i < splitsToSample; ++i) {
        TaskAttemptContext samplingContext = getTaskAttemptContext(job);
        RecordReader<K,V> reader = inf.createRecordReader(
            splits.get(i), samplingContext);
        reader.initialize(splits.get(i), samplingContext);
        while (reader.nextKeyValue()) {
          samples.add(ReflectionUtils.copy(job.getConfiguration(),
              reader.getCurrentKey(), null));
          ++records;
          if ((i+1) * samplesPerSplit <= records) {
            break;
          }
        }
        reader.close();
      }
      return (K[])samples.toArray();
    }
  }
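  /**
   * Creates a {@link TaskAttemptContext} through reflection rather than with a
   * direct constructor call, so that this backport compiles against Hadoop
   * versions in which TaskAttemptContext is a concrete class as well as those
   * in which it is an interface.
   * @param job the job whose configuration backs the new context
   * @return a TaskAttemptContext built from the job's configuration and a new
   *         TaskAttemptID
   * @throws IOException if the constructor cannot be found or invoked
   */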
  public static TaskAttemptContext getTaskAttemptContext(final Job job)
      throws IOException {
    Constructor<TaskAttemptContext> c;
    try {
      c = TaskAttemptContext.class.getConstructor(Configuration.class, TaskAttemptID.class);
    } catch (Exception e) {
      throw new IOException("Failed getting constructor", e);
    }
    try {
      return c.newInstance(job.getConfiguration(), new TaskAttemptID());
    } catch (Exception e) {
      throw new IOException("Failed creating instance", e);
    }
  }
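  /**
   * Sample from random points in the input.
   * General-purpose sampler. Takes numSamples / maxSplitsSampled inputs from
   * each split.
   */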
  public static class RandomSampler<K,V> implements Sampler<K,V> {
    private double freq;
    private final int numSamples;
    private final int maxSplitsSampled;
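    /**
     * Create a new RandomSampler sampling <em>all</em> splits.
     * This will read every split at the client, which is very expensive.
     * @param freq Probability with which a key will be chosen.
     * @param numSamples Total number of samples to obtain from all selected
     *                   splits.
     */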
    public RandomSampler(double freq, int numSamples) {
      this(freq, numSamples, Integer.MAX_VALUE);
    }
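    /**
     * Create a new RandomSampler.
     * @param freq Probability with which a key will be chosen.
     * @param numSamples Total number of samples to obtain from all selected
     *                   splits.
     * @param maxSplitsSampled The maximum number of splits to examine.
     */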
    public RandomSampler(double freq, int numSamples, int maxSplitsSampled) {
      this.freq = freq;
      this.numSamples = numSamples;
      this.maxSplitsSampled = maxSplitsSampled;
    }
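    /**
     * Randomize the split order, then take the specified number of keys from
     * each split sampled, where each key is selected with the specified
     * probability and possibly replaced by a subsequently selected key when
     * the quota of keys from that split is satisfied.
     */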
    @SuppressWarnings("unchecked")
    @Override
    public K[] getSample(InputFormat<K,V> inf, Job job)
        throws IOException, InterruptedException {
      List<InputSplit> splits = inf.getSplits(job);
      ArrayList<K> samples = new ArrayList<K>(numSamples);
      int splitsToSample = Math.min(maxSplitsSampled, splits.size());

      Random r = new Random();
      long seed = r.nextLong();
      r.setSeed(seed);
      LOG.debug("seed: " + seed);
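      // shuffle splits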
      for (int i = 0; i < splits.size(); ++i) {
        InputSplit tmp = splits.get(i);
        int j = r.nextInt(splits.size());
        splits.set(i, splits.get(j));
        splits.set(j, tmp);
      }
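      // our target rate is in terms of the maximum number of sample splits,
      // but we accept the possibility of sampling additional splits to hit
      // the target sample keyset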
      for (int i = 0; i < splitsToSample ||
          (i < splits.size() && samples.size() < numSamples); ++i) {
        TaskAttemptContext samplingContext = getTaskAttemptContext(job);
        RecordReader<K,V> reader = inf.createRecordReader(
            splits.get(i), samplingContext);
        reader.initialize(splits.get(i), samplingContext);
        while (reader.nextKeyValue()) {
          if (r.nextDouble() <= freq) {
            if (samples.size() < numSamples) {
              samples.add(ReflectionUtils.copy(job.getConfiguration(),
                  reader.getCurrentKey(), null));
            } else {
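              // When exceeding the maximum number of samples, replace a
              // random element with this one, then adjust the frequency
              // to reflect the possibility of existing elements being
              // pushed out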
              int ind = r.nextInt(numSamples);
              if (ind != numSamples) {
                samples.set(ind, ReflectionUtils.copy(job.getConfiguration(),
                    reader.getCurrentKey(), null));
              }
              freq *= (numSamples - 1) / (double) numSamples;
            }
          }
        }
        reader.close();
      }
      return (K[])samples.toArray();
    }
  }
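  /**
   * Sample from s splits at regular intervals.
   * Useful for sorted data.
   */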
  public static class IntervalSampler<K,V> implements Sampler<K,V> {
    private final double freq;
    private final int maxSplitsSampled;
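    /**
     * Create a new IntervalSampler sampling <em>all</em> splits.
     * @param freq The frequency with which records will be emitted.
     */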
    public IntervalSampler(double freq) {
      this(freq, Integer.MAX_VALUE);
    }
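    /**
     * Create a new IntervalSampler.
     * @param freq The frequency with which records will be emitted.
     * @param maxSplitsSampled The maximum number of splits to examine.
     */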
    public IntervalSampler(double freq, int maxSplitsSampled) {
      this.freq = freq;
      this.maxSplitsSampled = maxSplitsSampled;
    }
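    /**
     * For each split sampled, emit when the ratio of the number of records
     * retained to the total record count is less than the specified
     * frequency.
     */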
    @SuppressWarnings("unchecked")
    @Override
    public K[] getSample(InputFormat<K,V> inf, Job job)
        throws IOException, InterruptedException {
      List<InputSplit> splits = inf.getSplits(job);
      ArrayList<K> samples = new ArrayList<K>();
      int splitsToSample = Math.min(maxSplitsSampled, splits.size());
      long records = 0;
      long kept = 0;
      for (int i = 0; i < splitsToSample; ++i) {
        TaskAttemptContext samplingContext = getTaskAttemptContext(job);
        RecordReader<K,V> reader = inf.createRecordReader(
            splits.get(i), samplingContext);
        reader.initialize(splits.get(i), samplingContext);
        while (reader.nextKeyValue()) {
          ++records;
          if ((double) kept / records < freq) {
            samples.add(ReflectionUtils.copy(job.getConfiguration(),
                reader.getCurrentKey(), null));
            ++kept;
          }
        }
        reader.close();
      }
      return (K[])samples.toArray();
    }
  }
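  /**
   * Write a partition file for the given job, using the Sampler provided.
   * Queries the sampler for a sample keyset, sorts by the output key
   * comparator, selects the keys for each rank, and writes to the destination
   * returned from {@link TotalOrderPartitioner#getPartitionFile}.
   */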
  @SuppressWarnings("unchecked")
  public static <K,V> void writePartitionFile(Job job, Sampler<K,V> sampler)
      throws IOException, ClassNotFoundException, InterruptedException {
    Configuration conf = job.getConfiguration();
    final InputFormat inf =
        ReflectionUtils.newInstance(job.getInputFormatClass(), conf);
    int numPartitions = job.getNumReduceTasks();
    K[] samples = sampler.getSample(inf, job);
    LOG.info("Using " + samples.length + " samples");
    RawComparator<K> comparator =
        (RawComparator<K>) job.getSortComparator();
    Arrays.sort(samples, comparator);
    Path dst = new Path(TotalOrderPartitioner.getPartitionFile(conf));
    FileSystem fs = dst.getFileSystem(conf);
    if (fs.exists(dst)) {
      fs.delete(dst, false);
    }
    SequenceFile.Writer writer = SequenceFile.createWriter(fs,
        conf, dst, job.getMapOutputKeyClass(), NullWritable.class);
    NullWritable nullValue = NullWritable.get();
    float stepSize = samples.length / (float) numPartitions;
    int last = -1;
    for (int i = 1; i < numPartitions; ++i) {
      int k = Math.round(stepSize * i);
      while (last >= k && comparator.compare(samples[last], samples[k]) == 0) {
        ++k;
      }
      writer.append(samples[k], nullValue);
      last = k;
    }
    writer.close();
  }
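  /**
   * Driver for InputSampler from the command line.
   * Configures a {@link Job} instance and calls
   * {@link #writePartitionFile(Job, Sampler)}. The last non-option argument is
   * taken as the partition file to write; the preceding ones are input paths.
   * A hypothetical invocation (jar and paths are examples only):
   *
   * <pre>{@code
   * hadoop jar hbase.jar \
   *     org.apache.hadoop.hbase.mapreduce.hadoopbackport.InputSampler \
   *     -r 4 -splitRandom 0.1 10000 10 /data/input /tmp/partitions.lst
   * }</pre>
   */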
  @Override
  public int run(String[] args) throws Exception {
    Job job = new Job(getConf());
    ArrayList<String> otherArgs = new ArrayList<String>();
    Sampler<K,V> sampler = null;
    for (int i = 0; i < args.length; ++i) {
      try {
        if ("-r".equals(args[i])) {
          job.setNumReduceTasks(Integer.parseInt(args[++i]));
        } else if ("-inFormat".equals(args[i])) {
          job.setInputFormatClass(
              Class.forName(args[++i]).asSubclass(InputFormat.class));
        } else if ("-keyClass".equals(args[i])) {
          job.setMapOutputKeyClass(
              Class.forName(args[++i]).asSubclass(WritableComparable.class));
        } else if ("-splitSample".equals(args[i])) {
          int numSamples = Integer.parseInt(args[++i]);
          int maxSplits = Integer.parseInt(args[++i]);
          if (0 >= maxSplits) maxSplits = Integer.MAX_VALUE;
          sampler = new SplitSampler<K,V>(numSamples, maxSplits);
        } else if ("-splitRandom".equals(args[i])) {
          double pcnt = Double.parseDouble(args[++i]);
          int numSamples = Integer.parseInt(args[++i]);
          int maxSplits = Integer.parseInt(args[++i]);
          if (0 >= maxSplits) maxSplits = Integer.MAX_VALUE;
          sampler = new RandomSampler<K,V>(pcnt, numSamples, maxSplits);
        } else if ("-splitInterval".equals(args[i])) {
          double pcnt = Double.parseDouble(args[++i]);
          int maxSplits = Integer.parseInt(args[++i]);
          if (0 >= maxSplits) maxSplits = Integer.MAX_VALUE;
          sampler = new IntervalSampler<K,V>(pcnt, maxSplits);
        } else {
          otherArgs.add(args[i]);
        }
      } catch (NumberFormatException except) {
        System.out.println("ERROR: Integer expected instead of " + args[i]);
        return printUsage();
      } catch (ArrayIndexOutOfBoundsException except) {
        System.out.println("ERROR: Required parameter missing from " +
            args[i-1]);
        return printUsage();
      }
    }
    if (job.getNumReduceTasks() <= 1) {
      System.err.println("Sampler requires more than one reducer");
      return printUsage();
    }
    if (otherArgs.size() < 2) {
      System.out.println("ERROR: Wrong number of parameters: ");
      return printUsage();
    }
    if (null == sampler) {
      sampler = new RandomSampler<K,V>(0.1, 10000, 10);
    }

    Path outf = new Path(otherArgs.remove(otherArgs.size() - 1));
    TotalOrderPartitioner.setPartitionFile(job.getConfiguration(), outf);
    for (String s : otherArgs) {
      FileInputFormat.addInputPath(job, new Path(s));
    }
    InputSampler.<K,V>writePartitionFile(job, sampler);

    return 0;
  }

  public static void main(String[] args) throws Exception {
    InputSampler<?,?> sampler = new InputSampler(new Configuration());
    int res = ToolRunner.run(sampler, args);
    System.exit(res);
  }
}