App Engine Python SDK version 1.7.7
[gae.git] / python / google / appengine / ext / mapreduce / handlers.py

#!/usr/bin/env python
#
# Copyright 2007 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
32 """Defines executor tasks handlers for MapReduce implementation."""

import datetime
import gc
import logging
import math
import os
import sys
import time
import traceback

from google.appengine.api import taskqueue
from google.appengine.ext import db
from google.appengine.ext.mapreduce import base_handler
from google.appengine.ext.mapreduce import context
from google.appengine.ext.mapreduce import errors
from google.appengine.ext.mapreduce import input_readers
from google.appengine.ext.mapreduce import model
from google.appengine.ext.mapreduce import operation
from google.appengine.ext.mapreduce import parameters
from google.appengine.ext.mapreduce import util

try:
  from google.appengine.ext import ndb
except ImportError:
  ndb = None


# Amount of time (in seconds) to spend scanning in a single slice; a new
# slice is scheduled as soon as the current one has run this long.
_SLICE_DURATION_SEC = 15

# Delay (in seconds) between consecutive controller callback invocations.
_CONTROLLER_PERIOD_SEC = 2

# Maximum number of task retries for a slice before its shard is failed
# permanently.
_RETRY_SLICE_ERROR_MAX_RETRIES = 10

# Set of strings naming faults injected by tests.
_TEST_INJECTED_FAULTS = set()


def _run_task_hook(hooks, method, task, queue_name):
  """Invokes hooks.method(task, queue_name).

  Args:
    hooks: A hooks.Hooks instance or None.
    method: The name of the method to invoke on the hooks class e.g.
        "enqueue_kickoff_task".
    task: The taskqueue.Task to pass to the hook method.
    queue_name: The name of the queue to pass to the hook method.

  Returns:
    True if the hooks.Hooks instance handled the method, False otherwise.
  """
  if hooks is not None:
    try:
      getattr(hooks, method)(task, queue_name)
    except NotImplementedError:
      # The hooks class does not handle this method; fall back to the
      # default task addition path.
      return False
    return True
  return False
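
# Illustrative only (not part of this module): a hooks class overrides the
# enqueue_* methods it cares about and leaves the rest raising
# NotImplementedError (the hooks.Hooks base class default), so _run_task_hook
# falls back to the plain task.add() path above. A minimal sketch, assuming
# the google.appengine.ext.mapreduce.hooks module and a hypothetical
# "mapper-queue":
#
#   from google.appengine.ext.mapreduce import hooks as mr_hooks
#
#   class QueueRoutingHooks(mr_hooks.Hooks):
#     """Routes worker tasks to a dedicated queue."""
#
#     def enqueue_worker_task(self, task, queue_name):
#       task.add("mapper-queue")  # ignore the suggested queue_name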


class MapperWorkerCallbackHandler(util.HugeTaskHandler):
  """Callback handler for mapreduce worker task.

  Request Parameters:
    mapreduce_spec: MapreduceSpec of the mapreduce serialized to json.
    shard_id: id of the shard.
    slice_id: id of the slice.
  """

  def __init__(self, *args):
    """Constructor."""
    util.HugeTaskHandler.__init__(self, *args)
    self._time = time.time

  def handle(self):
    """Handle request."""
    tstate = model.TransientShardState.from_request(self.request)
    spec = tstate.mapreduce_spec
    self._start_time = self._time()
    shard_id = tstate.shard_id

    shard_state, control = db.get([
        model.ShardState.get_key_by_shard_id(shard_id),
        model.MapreduceControl.get_key_by_job_id(spec.mapreduce_id),
    ])
    if not shard_state:
      logging.error("State not found for shard %s; Possible spurious task "
                    "execution. Dropping this task.",
                    shard_id)
      return

    if not shard_state.active:
      logging.error("Shard %s is not active. Possible spurious task "
                    "execution. Dropping this task.", shard_id)
      logging.error(str(shard_state))
      return
    if shard_state.retries > tstate.retries:
      logging.error(
          "Got shard %s from previous shard retry %s. Possible spurious "
          "task execution. Dropping this task.",
          shard_id,
          tstate.retries)
      logging.error(str(shard_state))
      return
    elif shard_state.retries < tstate.retries:
      # The stored ShardState is behind this slice (e.g. the write that
      # scheduled the retry is not yet visible); raise so the task is
      # retried until the state catches up.
      raise ValueError(
          "ShardState for %s is behind slice. Waiting for it to catch up." %
          shard_state.shard_id)

    ctx = context.Context(spec, shard_state,
                          task_retry_count=self.task_retry_count())

    if control and control.command == model.MapreduceControl.ABORT:
      logging.info("Abort command received by shard %d of job '%s'",
                   shard_state.shard_number, shard_state.mapreduce_id)
      # The output writer is deliberately not finalized on abort, since the
      # shard's output may be incomplete or inconsistent.
      shard_state.active = False
      shard_state.result_status = model.ShardState.RESULT_ABORTED
      shard_state.put(config=util.create_datastore_write_config(spec))
      model.MapreduceControl.abort(spec.mapreduce_id)
      return

    input_reader = tstate.input_reader

    # Disable the NDB in-process and memcache caches for the duration of
    # this slice: mappers may read many entities, and caching them would
    # bloat request memory without benefiting later slices.
    if ndb is not None:
      ndb_ctx = ndb.get_context()
      ndb_ctx.set_cache_policy(lambda key: False)
      ndb_ctx.set_memcache_policy(lambda key: False)

    context.Context._set(ctx)
    retry_shard = False

    try:
      self.process_inputs(
          input_reader, shard_state, tstate, ctx)
      if not shard_state.active:
        # The shard is done; finalize the output writer, but only on success.
        if (shard_state.result_status == model.ShardState.RESULT_SUCCESS and
            tstate.output_writer):
          tstate.output_writer.finalize(ctx, shard_state)
    except Exception, e:
      retry_shard = self._retry_logic(e, shard_state, tstate, spec.mapreduce_id)
    finally:
      context.Context._set(None)

    config = util.create_datastore_write_config(spec)

    # Save the shard state transactionally, re-checking the stored state so
    # that a spurious or duplicate task cannot overwrite newer state.
    @db.transactional(retries=5)
    def tx():
      fresh_shard_state = db.get(
          model.ShardState.get_key_by_shard_id(shard_id))
      if not fresh_shard_state:
        raise db.Rollback()
      if (not fresh_shard_state.active or
          "worker_active_state_collision" in _TEST_INJECTED_FAULTS):
        logging.error("Shard %s is not active. Possible spurious task "
                      "execution. Dropping this task.", shard_id)
        logging.error("Datastore's %s", str(fresh_shard_state))
        logging.error("Slice's %s", str(shard_state))
        return
      fresh_shard_state.copy_from(shard_state)
      fresh_shard_state.put(config=config)
      if retry_shard:
        # Shard retry: schedule the first slice of the new attempt.
        self._schedule_slice(fresh_shard_state, tstate)
      elif shard_state.active:
        # The shard is still active: schedule the next slice.
        self.reschedule(fresh_shard_state, tstate)
    tx()

    gc.collect()

  def process_inputs(self,
                     input_reader,
                     shard_state,
                     tstate,
                     ctx):
    """Read inputs, process them, and write out outputs.

    This is the core logic of MapReduce. It reads inputs from the input
    reader, invokes the user-specified mapper function, and writes output
    with the output writer. It also updates shard_state accordingly,
    e.g. when shard processing is done, it sets shard_state.active to False.

    If errors.FailJobError is caught, it will fail this MR job.
    All other exceptions will be logged and raised to taskqueue for retry
    until the number of retries exceeds a limit.

    Args:
      input_reader: input reader.
      shard_state: shard state.
      tstate: transient shard state.
      ctx: mapreduce context.
    """
    processing_limit = self._processing_limit(tstate.mapreduce_spec)
    if processing_limit == 0:
      return

    finished_shard = True
    for entity in input_reader:
      # Remember the last work item for status reporting.
      if isinstance(entity, db.Model):
        shard_state.last_work_item = repr(entity.key())
      elif ndb and isinstance(entity, ndb.Model):
        shard_state.last_work_item = repr(entity.key)
      else:
        shard_state.last_work_item = repr(entity)[:100]

      processing_limit -= 1

      if not self.process_data(
          entity, input_reader, ctx, tstate):
        finished_shard = False
        break
      elif processing_limit == 0:
        finished_shard = False
        break

    # Record mapper wall time and flush the context's pools.
    operation.counters.Increment(
        context.COUNTER_MAPPER_WALLTIME_MS,
        int((self._time() - self._start_time)*1000))(ctx)
    ctx.flush()

    if finished_shard:
      shard_state.active = False
      shard_state.result_status = model.ShardState.RESULT_SUCCESS

  def process_data(self, data, input_reader, ctx, transient_shard_state):
    """Process a single data piece.

    Call mapper handler on the data.

    Args:
      data: a datum to process.
      input_reader: input reader.
      ctx: mapreduce context.
      transient_shard_state: transient shard state.

    Returns:
      True if scan should be continued, False if scan should be stopped.
    """
    if data is not input_readers.ALLOW_CHECKPOINT:
      ctx.counters.increment(context.COUNTER_MAPPER_CALLS)

      handler = ctx.mapreduce_spec.mapper.handler
      if input_reader.expand_parameters:
        result = handler(*data)
      else:
        result = handler(data)

      if util.is_generator(handler):
        for output in result:
          if isinstance(output, operation.Operation):
            output(ctx)
          else:
            output_writer = transient_shard_state.output_writer
            if not output_writer:
              logging.error(
                  "Handler yielded %s, but no output writer is set.", output)
            else:
              output_writer.write(output, ctx)

    if self._time() - self._start_time > _SLICE_DURATION_SEC:
      return False
    return True
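
  # Illustrative only: mapper handlers invoked above may be plain functions
  # or generators (detected via util.is_generator). A generator may yield
  # operation objects, which are applied to the context, or raw values, which
  # go to the configured output writer. A hypothetical handler:
  #
  #   def my_mapper(entity):
  #     yield operation.counters.Increment("rows-seen")
  #     yield "%s\n" % entity.key()  # handled by the output writer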

  def _retry_logic(self, e, shard_state, tstate, mr_id):
    """Handle retry for this slice.

    Args:
      e: the exception caught.
      shard_state: model.ShardState for current shard.
      tstate: model.TransientShardState for current shard.
      mr_id: mapreduce id.

    Returns:
      True if the whole shard should be retried. False otherwise.

    Raises:
      errors.RetrySliceError: if the slice should be retried.
    """
    logging.error("Shard %s got error.", shard_state.shard_id)
    logging.error(traceback.format_exc())

    # A FailJobError raised by the handler fails the shard (and hence the
    # job) permanently.
    if type(e) is errors.FailJobError:
      logging.error("Got FailJobError. Shard %s failed permanently.",
                    shard_state.shard_id)
      shard_state.active = False
      shard_state.result_status = model.ShardState.RESULT_FAILED
      return False

    if type(e) in errors.SHARD_RETRY_ERRORS:
      shard_retry = shard_state.retries
      if shard_retry < parameters.DEFAULT_SHARD_RETRY_LIMIT:
        if tstate.output_writer and (
            not tstate.output_writer._can_be_retried(tstate)):
          logging.error("Can not retry shard. Shard %s failed permanently.",
                        shard_state.shard_id)
          shard_state.active = False
          shard_state.result_status = model.ShardState.RESULT_FAILED
          return False

        shard_state.reset_for_retry()
        logging.error("Shard %s will be retried for the %s time.",
                      shard_state.shard_id,
                      shard_state.retries)
        output_writer = None
        if tstate.output_writer:
          mr_state = model.MapreduceState.get_by_job_id(mr_id)
          output_writer = tstate.output_writer.create(
              mr_state, shard_state)
        tstate.reset_for_retry(output_writer)
        return True

    else:
      slice_retry = self.task_retry_count()
      if slice_retry < _RETRY_SLICE_ERROR_MAX_RETRIES:
        logging.error(
            "Will retry slice %s %s for the %s time.",
            tstate.shard_id,
            tstate.slice_id,
            slice_retry + 1)
        # Clear the exception info before raising: the traceback can keep a
        # large amount of state alive.
        sys.exc_clear()
        raise errors.RetrySliceError("Raise an error to trigger slice retry")

    logging.error("Slice reached max retry limit of %s. "
                  "Shard %s failed permanently.",
                  self.task_retry_count(),
                  shard_state.shard_id)
    shard_state.active = False
    shard_state.result_status = model.ShardState.RESULT_FAILED
    return False

  @staticmethod
  def get_task_name(shard_id, slice_id, retry=0):
    """Compute single worker task name.

    Args:
      shard_id: shard id as string.
      slice_id: slice id as int.
      retry: the shard retry count this task belongs to.

    Returns:
      task name which should be used to process specified shard/slice.
    """
    # Prefix the name with "appengine-mrshard" so the task names cannot
    # collide with other tasks on the queue.
    return "appengine-mrshard-%s-%s-retry-%s" % (
        shard_id, slice_id, retry)
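
  # For example, the first attempt at slice 3 of shard "1234567890-0" would
  # be named "appengine-mrshard-1234567890-0-3-retry-0" (ids illustrative).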

  def reschedule(self, shard_state, tstate):
    """Reschedule worker task to continue scanning work.

    Args:
      shard_state: model.ShardState for the current shard.
      tstate: an instance of TransientShardState.
    """
    tstate.slice_id += 1
    spec = tstate.mapreduce_spec
    countdown = 0
    if self._processing_limit(spec) != -1:
      # A processing rate is in effect: wait out the remainder of this
      # slice's time budget before starting the next slice.
      countdown = max(
          int(_SLICE_DURATION_SEC - (self._time() - self._start_time)), 0)
    MapperWorkerCallbackHandler._schedule_slice(
        shard_state, tstate, countdown=countdown)

  @classmethod
  def _schedule_slice(cls,
                      shard_state,
                      transient_shard_state,
                      queue_name=None,
                      eta=None,
                      countdown=None):
    """Schedule slice scanning by adding it to the task queue.

    Args:
      shard_state: An instance of ShardState.
      transient_shard_state: An instance of TransientShardState.
      queue_name: Optional queue to run on; uses the current queue of
        execution or the default queue if unspecified.
      eta: Absolute time when the MR should execute. May not be specified
        if 'countdown' is also supplied. This may be timezone-aware or
        timezone-naive.
      countdown: Time in seconds into the future that this MR should execute.
        Defaults to zero.
    """
    base_path = transient_shard_state.base_path
    mapreduce_spec = transient_shard_state.mapreduce_spec

    task_name = MapperWorkerCallbackHandler.get_task_name(
        transient_shard_state.shard_id,
        transient_shard_state.slice_id,
        transient_shard_state.retries)
    queue_name = queue_name or os.environ.get("HTTP_X_APPENGINE_QUEUENAME",
                                              "default")

    worker_task = util.HugeTask(url=base_path + "/worker_callback",
                                params=transient_shard_state.to_dict(),
                                name=task_name,
                                eta=eta,
                                countdown=countdown)

    if not _run_task_hook(mapreduce_spec.get_hooks(),
                          "enqueue_worker_task",
                          worker_task,
                          queue_name):
      try:
        worker_task.add(queue_name, parent=shard_state)
      except (taskqueue.TombstonedTaskError,
              taskqueue.TaskAlreadyExistsError), e:
        # Task names are deterministic, so a duplicate add attempt simply
        # means this slice was already scheduled.
        logging.warning("Task %r with params %r already exists. %s: %s",
                        task_name,
                        transient_shard_state.to_dict(),
                        e.__class__,
                        e)

  def _processing_limit(self, spec):
    """Get the limit on the number of map calls allowed by this slice.

    Args:
      spec: a Mapreduce spec.

    Returns:
      The limit as a positive int if specified by user. -1 otherwise.
    """
    processing_rate = float(spec.mapper.params.get("processing_rate", 0))
    slice_processing_limit = -1
    if processing_rate > 0:
      # Spread the job-wide rate evenly across shards: each slice may make
      # at most ceil(slice_duration * rate / shard_count) calls.
      slice_processing_limit = int(math.ceil(
          _SLICE_DURATION_SEC*processing_rate/int(spec.mapper.shard_count)))
    return slice_processing_limit
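
  # Worked example with illustrative numbers: processing_rate=1000 and
  # shard_count=8 give a per-slice limit of ceil(15 * 1000 / 8) = 1875 map
  # calls; with no positive processing_rate the limit is -1 (unlimited).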


class ControllerCallbackHandler(util.HugeTaskHandler):
  """Supervises mapreduce execution.

  It is also responsible for aggregating execution status from the shards.

  The task runs "continuously": it re-adds itself to the task queue for as
  long as the mapreduce is still active.
  """

  def __init__(self, *args):
    """Constructor."""
    util.HugeTaskHandler.__init__(self, *args)
    self._time = time.time

  def handle(self):
    """Handle request."""
    spec = model.MapreduceSpec.from_json_str(
        self.request.get("mapreduce_spec"))

    state, control = db.get([
        model.MapreduceState.get_key_by_job_id(spec.mapreduce_id),
        model.MapreduceControl.get_key_by_job_id(spec.mapreduce_id),
    ])
    if not state:
      logging.error("State not found for mapreduce_id '%s'; skipping",
                    spec.mapreduce_id)
      return

    shard_states = model.ShardState.find_by_mapreduce_state(state)
    if state.active and len(shard_states) != spec.mapper.shard_count:
      # Some shard states are missing; fail and abort the job.
      logging.error("Incorrect number of shard states: %d vs %d; "
                    "aborting job '%s'",
                    len(shard_states), spec.mapper.shard_count,
                    spec.mapreduce_id)
      state.active = False
      state.result_status = model.MapreduceState.RESULT_FAILED
      model.MapreduceControl.abort(spec.mapreduce_id)

    active_shards = [s for s in shard_states if s.active]
    failed_shards = [s for s in shard_states
                     if s.result_status == model.ShardState.RESULT_FAILED]
    aborted_shards = [s for s in shard_states
                      if s.result_status == model.ShardState.RESULT_ABORTED]
    if state.active:
      state.active = bool(active_shards)
      state.active_shards = len(active_shards)
      state.failed_shards = len(failed_shards)
      state.aborted_shards = len(aborted_shards)
      if not control and failed_shards:
        # Issue an abort command so the remaining shards stop as well.
        model.MapreduceControl.abort(spec.mapreduce_id)

    if (not state.active and control and
        control.command == model.MapreduceControl.ABORT):
      # A user-initiated abort has completed.
      logging.info("Abort signal received for job '%s'", spec.mapreduce_id)
      state.result_status = model.MapreduceState.RESULT_ABORTED

    if not state.active:
      state.active_shards = 0
      if not state.result_status:
        # Derive the final result status from the shard states.
        if [s for s in shard_states
            if s.result_status != model.ShardState.RESULT_SUCCESS]:
          state.result_status = model.MapreduceState.RESULT_FAILED
        else:
          state.result_status = model.MapreduceState.RESULT_SUCCESS
      logging.info("Final result for job '%s' is '%s'",
                   spec.mapreduce_id, state.result_status)

    self.aggregate_state(state, shard_states)
    state.last_poll_time = datetime.datetime.utcfromtimestamp(self._time())

    if not state.active:
      ControllerCallbackHandler._finalize_job(
          spec, state, self.base_path())
      return

    config = util.create_datastore_write_config(spec)
    state.put(config=config)

    ControllerCallbackHandler.reschedule(
        state, self.base_path(), spec, self.serial_id() + 1)

  def aggregate_state(self, mapreduce_state, shard_states):
    """Update current mapreduce state by aggregating shard states.

    Args:
      mapreduce_state: current mapreduce state as MapreduceState.
      shard_states: all shard states (active and inactive). list of ShardState.
    """
    processed_counts = []
    mapreduce_state.counters_map.clear()

    for shard_state in shard_states:
      mapreduce_state.counters_map.add_map(shard_state.counters_map)
      processed_counts.append(shard_state.counters_map.get(
          context.COUNTER_MAPPER_CALLS))

    mapreduce_state.set_processed_counts(processed_counts)

  def serial_id(self):
    """Get serial unique identifier of this task from request.

    Returns:
      serial identifier as int.
    """
    return int(self.request.get("serial_id"))

  @staticmethod
  def _finalize_job(mapreduce_spec, mapreduce_state, base_path):
    """Finalize job execution.

    Finalizes the output writer, invokes the done callback, and schedules
    the finalize-job task.

    Args:
      mapreduce_spec: an instance of MapreduceSpec.
      mapreduce_state: an instance of MapreduceState.
      base_path: handler base path.
    """
    config = util.create_datastore_write_config(mapreduce_spec)

    # Only finalize the output writers if the job completed successfully.
    if (mapreduce_spec.mapper.output_writer_class() and
        mapreduce_state.result_status == model.MapreduceState.RESULT_SUCCESS):
      mapreduce_spec.mapper.output_writer_class().finalize_job(mapreduce_state)

    def put_state(state):
      state.put(config=config)
      # Enqueue the done callback, if one was specified, transactionally
      # with the state save.
      done_callback = mapreduce_spec.params.get(
          model.MapreduceSpec.PARAM_DONE_CALLBACK)
      if done_callback:
        done_task = taskqueue.Task(
            url=done_callback,
            headers={"Mapreduce-Id": mapreduce_spec.mapreduce_id},
            method=mapreduce_spec.params.get("done_callback_method", "POST"))
        queue_name = mapreduce_spec.params.get(
            model.MapreduceSpec.PARAM_DONE_CALLBACK_QUEUE,
            "default")

        if not _run_task_hook(mapreduce_spec.get_hooks(),
                              "enqueue_done_task",
                              done_task,
                              queue_name):
          done_task.add(queue_name, transactional=True)
      FinalizeJobHandler.schedule(base_path, mapreduce_spec)

    db.run_in_transaction_custom_retries(5, put_state, mapreduce_state)

  @staticmethod
  def get_task_name(mapreduce_spec, serial_id):
    """Compute single controller task name.

    Args:
      mapreduce_spec: an instance of MapreduceSpec.
      serial_id: id of the controller invocation as int.

    Returns:
      task name which should be used for this controller invocation.
    """
    # Prefix the name with "appengine-mrcontrol" so the task names cannot
    # collide with other tasks on the queue.
    return "appengine-mrcontrol-%s-%s" % (
        mapreduce_spec.mapreduce_id, serial_id)

  @staticmethod
  def controller_parameters(mapreduce_spec, serial_id):
    """Fill in controller task parameters.

    The returned parameter map is used as the task payload and contains all
    the data the controller needs to perform its function.

    Args:
      mapreduce_spec: specification of the mapreduce.
      serial_id: id of the invocation as int.

    Returns:
      string->string map of parameters to be used as task payload.
    """
    return {"mapreduce_spec": mapreduce_spec.to_json_str(),
            "serial_id": str(serial_id)}

  @classmethod
  def reschedule(cls,
                 mapreduce_state,
                 base_path,
                 mapreduce_spec,
                 serial_id,
                 queue_name=None):
    """Schedule new update status callback task.

    Args:
      mapreduce_state: mapreduce state as model.MapreduceState.
      base_path: mapreduce handlers url base path as string.
      mapreduce_spec: mapreduce specification as MapreduceSpec.
      serial_id: id of the invocation as int.
      queue_name: The queue to schedule this task on. Will use the current
        queue of execution if not supplied.
    """
    task_name = ControllerCallbackHandler.get_task_name(
        mapreduce_spec, serial_id)
    task_params = ControllerCallbackHandler.controller_parameters(
        mapreduce_spec, serial_id)
    if not queue_name:
      queue_name = os.environ.get("HTTP_X_APPENGINE_QUEUENAME", "default")

    controller_callback_task = util.HugeTask(
        url=base_path + "/controller_callback",
        name=task_name, params=task_params,
        countdown=_CONTROLLER_PERIOD_SEC)

    if not _run_task_hook(mapreduce_spec.get_hooks(),
                          "enqueue_controller_task",
                          controller_callback_task,
                          queue_name):
      try:
        controller_callback_task.add(queue_name, parent=mapreduce_state)
      except (taskqueue.TombstonedTaskError,
              taskqueue.TaskAlreadyExistsError), e:
        logging.warning("Task %r with params %r already exists. %s: %s",
                        task_name, task_params, e.__class__, e)


class KickOffJobHandler(util.HugeTaskHandler):
  """Taskqueue handler which kicks off mapreduce processing.

  Request Parameters:
    mapreduce_spec: MapreduceSpec of the mapreduce serialized to json.
    input_readers: List of InputReader objects separated by semi-colons.
  """

  def handle(self):
    """Handles kick off request."""
    spec = model.MapreduceSpec.from_json_str(
        self._get_required_param("mapreduce_spec"))

    app_id = self.request.get("app", None)
    queue_name = os.environ.get("HTTP_X_APPENGINE_QUEUENAME", "default")
    mapper_input_reader_class = spec.mapper.input_reader_class()

    # Create the job state. StartJobHandler may already have saved an
    # initial version; the put below simply overwrites it.
    state = model.MapreduceState.create_new(spec.mapreduce_id)
    state.mapreduce_spec = spec
    state.active = True
    if app_id:
      state.app_id = app_id

    input_readers = mapper_input_reader_class.split_input(spec.mapper)
    if not input_readers:
      # There is no input data: finalize the job immediately.
      logging.warning("Found no mapper input data to process.")
      state.active = False
      state.active_shards = 0
      ControllerCallbackHandler._finalize_job(spec, state, self.base_path())
      return

    # Update the state and spec with the actual shard count.
    spec.mapper.shard_count = len(input_readers)
    state.active_shards = len(input_readers)
    state.mapreduce_spec = spec

    output_writer_class = spec.mapper.output_writer_class()
    if output_writer_class:
      output_writer_class.init_job(state)

    state.put(config=util.create_datastore_write_config(spec))

    KickOffJobHandler._schedule_shards(
        spec, input_readers, queue_name, self.base_path(), state)

    ControllerCallbackHandler.reschedule(
        state, self.base_path(), spec, queue_name=queue_name, serial_id=0)

  def _get_required_param(self, param_name):
    """Get a required request parameter.

    Args:
      param_name: name of request parameter to fetch.

    Returns:
      parameter value

    Raises:
      errors.NotEnoughArgumentsError: if parameter is not specified.
    """
    value = self.request.get(param_name)
    if not value:
      raise errors.NotEnoughArgumentsError(param_name + " not specified")
    return value

  @classmethod
  def _schedule_shards(cls,
                       spec,
                       input_readers,
                       queue_name,
                       base_path,
                       mr_state):
    """Prepares shard states and schedules their execution.

    Args:
      spec: mapreduce specification as MapreduceSpec.
      input_readers: list of InputReaders describing shard splits.
      queue_name: The queue to run this job on.
      base_path: The base url path of mapreduce callbacks.
      mr_state: The MapReduceState of current job.
    """
    # Create shard states and, if configured, an output writer per shard.
    shard_states = []
    writer_class = spec.mapper.output_writer_class()
    output_writers = [None] * len(input_readers)
    for shard_number, input_reader in enumerate(input_readers):
      shard_state = model.ShardState.create_new(spec.mapreduce_id, shard_number)
      shard_state.shard_description = str(input_reader)
      if writer_class:
        output_writers[shard_number] = writer_class.create(
            mr_state, shard_state)
      shard_states.append(shard_state)

    # Retrieve already existing shard states.
    existing_shard_states = db.get(shard.key() for shard in shard_states)
    existing_shard_keys = set(shard.key() for shard in existing_shard_states
                              if shard is not None)

    # Save only non-existent shard states; an existing one was written by a
    # previous attempt of this (retryable) kickoff task.
    db.put((shard for shard in shard_states
            if shard.key() not in existing_shard_keys),
           config=util.create_datastore_write_config(spec))

    # Schedule the first slice of each shard.
    for shard_number, (input_reader, output_writer) in enumerate(
        zip(input_readers, output_writers)):
      shard_id = model.ShardState.shard_id_from_number(
          spec.mapreduce_id, shard_number)
      MapperWorkerCallbackHandler._schedule_slice(
          shard_states[shard_number],
          model.TransientShardState(
              base_path, spec, shard_id, 0, input_reader, input_reader,
              output_writer=output_writer),
          queue_name=queue_name)


class StartJobHandler(base_handler.PostJsonHandler):
  """Command handler that starts a mapreduce job."""

  def handle(self):
    """Handles start request."""
    # Mapper spec as form arguments.
    mapreduce_name = self._get_required_param("name")
    mapper_input_reader_spec = self._get_required_param("mapper_input_reader")
    mapper_handler_spec = self._get_required_param("mapper_handler")
    mapper_output_writer_spec = self.request.get("mapper_output_writer")
    mapper_params = self._get_params(
        "mapper_params_validator", "mapper_params.")
    params = self._get_params(
        "params_validator", "params.")

    # Set some mapper param defaults if not present.
    mapper_params["processing_rate"] = int(mapper_params.get(
        "processing_rate") or model._DEFAULT_PROCESSING_RATE_PER_SEC)
    queue_name = mapper_params["queue_name"] = mapper_params.get(
        "queue_name", "default")

    # Validate the mapper specification, handler, and input reader.
    mapper_spec = model.MapperSpec(
        mapper_handler_spec,
        mapper_input_reader_spec,
        mapper_params,
        int(mapper_params.get("shard_count", model._DEFAULT_SHARD_COUNT)),
        output_writer_spec=mapper_output_writer_spec)

    mapreduce_id = type(self)._start_map(
        mapreduce_name,
        mapper_spec,
        params,
        base_path=self.base_path(),
        queue_name=queue_name,
        _app=mapper_params.get("_app"))
    self.json_response["mapreduce_id"] = mapreduce_id

  def _get_params(self, validator_parameter, name_prefix):
    """Retrieves additional user-supplied params for the job and validates them.

    Args:
      validator_parameter: name of the request parameter which supplies
        validator for this parameter set.
      name_prefix: common prefix for all parameter names in the request.

    Returns:
      The user parameters as a dict, with name_prefix stripped from the keys.

    Raises:
      Any exception raised by the 'params_validator' request parameter if
      the params fail to validate.
    """
    params_validator = self.request.get(validator_parameter)

    user_params = {}
    for key in self.request.arguments():
      if key.startswith(name_prefix):
        values = self.request.get_all(key)
        adjusted_key = key[len(name_prefix):]
        if len(values) == 1:
          user_params[adjusted_key] = values[0]
        else:
          user_params[adjusted_key] = values

    if params_validator:
      resolved_validator = util.for_name(params_validator)
      resolved_validator(user_params)

    return user_params
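
  # Illustrative example: a request carrying "mapper_params.entity_kind=Foo"
  # and "mapper_params.queue_name=default", fetched with
  # name_prefix="mapper_params.", yields:
  #   {"entity_kind": "Foo", "queue_name": "default"}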

  def _get_required_param(self, param_name):
    """Get a required request parameter.

    Args:
      param_name: name of request parameter to fetch.

    Returns:
      parameter value

    Raises:
      errors.NotEnoughArgumentsError: if parameter is not specified.
    """
    value = self.request.get(param_name)
    if not value:
      raise errors.NotEnoughArgumentsError(param_name + " not specified")
    return value

  @classmethod
  def _start_map(cls,
                 name,
                 mapper_spec,
                 mapreduce_params,
                 base_path=None,
                 queue_name=None,
                 eta=None,
                 countdown=None,
                 hooks_class_name=None,
                 _app=None,
                 transactional=False,
                 parent_entity=None):
    """Start a new map job and return its mapreduce id."""
    queue_name = queue_name or os.environ.get("HTTP_X_APPENGINE_QUEUENAME",
                                              "default")
    if queue_name[0] == "_":
      # We are currently in some special queue. E.g. __cron.
      queue_name = "default"

    if not transactional and parent_entity:
      raise Exception("Parent shouldn't be specified "
                      "for non-transactional starts.")

    # Check that the input reader (and output writer, if any) can be
    # instantiated and are configured correctly.
    mapper_input_reader_class = mapper_spec.input_reader_class()
    mapper_input_reader_class.validate(mapper_spec)

    mapper_output_writer_class = mapper_spec.output_writer_class()
    if mapper_output_writer_class:
      mapper_output_writer_class.validate(mapper_spec)

    mapreduce_id = model.MapreduceState.new_mapreduce_id()
    mapreduce_spec = model.MapreduceSpec(
        name,
        mapreduce_id,
        mapper_spec.to_json(),
        mapreduce_params,
        hooks_class_name)

    # Check that the handler can be instantiated.
    ctx = context.Context(mapreduce_spec, None)
    context.Context._set(ctx)
    try:
      mapper_spec.get_handler()
    finally:
      context.Context._set(None)

    kickoff_params = {"mapreduce_spec": mapreduce_spec.to_json_str()}
    if _app:
      kickoff_params["app"] = _app
    kickoff_worker_task = util.HugeTask(
        url=base_path + "/kickoffjob_callback",
        params=kickoff_params,
        eta=eta,
        countdown=countdown)

    hooks = mapreduce_spec.get_hooks()
    config = util.create_datastore_write_config(mapreduce_spec)

    def start_mapreduce():
      parent = parent_entity
      if not transactional:
        # Save the state in the datastore so that the UI can see it. The
        # kickoff task is then added as a transactional child of this state.
        state = model.MapreduceState.create_new(mapreduce_spec.mapreduce_id)
        state.mapreduce_spec = mapreduce_spec
        state.active = True
        state.active_shards = mapper_spec.shard_count
        if _app:
          state.app_id = _app
        state.put(config=config)
        parent = state

      if hooks is not None:
        try:
          hooks.enqueue_kickoff_task(kickoff_worker_task, queue_name)
        except NotImplementedError:
          # Use the default task addition implementation.
          pass
        else:
          return
      kickoff_worker_task.add(queue_name, transactional=True, parent=parent)

    if transactional:
      start_mapreduce()
    else:
      db.run_in_transaction(start_mapreduce)

    return mapreduce_id
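
  # A minimal usage sketch with illustrative names (my_module and its members
  # are hypothetical; the reader spec path depends on how the library is
  # deployed):
  #
  #   mapreduce_id = StartJobHandler._start_map(
  #       "Process MyEntity",
  #       model.MapperSpec(
  #           "my_module.my_mapper",
  #           "google.appengine.ext.mapreduce.input_readers."
  #           "DatastoreInputReader",
  #           {"entity_kind": "my_module.MyEntity", "processing_rate": 100},
  #           8),  # shard_count
  #       {},  # mapreduce_params
  #       base_path="/mapreduce",
  #       queue_name="default")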


class FinalizeJobHandler(base_handler.TaskQueueHandler):
  """Finalize map job by deleting all temporary entities."""

  def handle(self):
    mapreduce_id = self.request.get("mapreduce_id")
    mapreduce_state = model.MapreduceState.get_by_job_id(mapreduce_id)
    if mapreduce_state:
      config = util.create_datastore_write_config(
          mapreduce_state.mapreduce_spec)
      db.delete(model.MapreduceControl.get_key_by_job_id(mapreduce_id),
                config=config)
      shard_states = model.ShardState.find_by_mapreduce_state(mapreduce_state)
      for shard_state in shard_states:
        db.delete(util._HugeTaskPayload.all().ancestor(shard_state),
                  config=config)
      db.delete(shard_states, config=config)
      db.delete(util._HugeTaskPayload.all().ancestor(mapreduce_state),
                config=config)

  @classmethod
  def schedule(cls, base_path, mapreduce_spec):
    """Schedule finalize task.

    Args:
      base_path: mapreduce handlers url base path as string.
      mapreduce_spec: mapreduce specification as MapreduceSpec.
    """
    task_name = mapreduce_spec.mapreduce_id + "-finalize"
    finalize_task = taskqueue.Task(
        name=task_name,
        url=base_path + "/finalizejob_callback",
        params={"mapreduce_id": mapreduce_spec.mapreduce_id})
    queue_name = os.environ.get("HTTP_X_APPENGINE_QUEUENAME", "default")
    if not _run_task_hook(mapreduce_spec.get_hooks(),
                          "enqueue_controller_task",
                          finalize_task,
                          queue_name):
      try:
        finalize_task.add(queue_name)
      except (taskqueue.TombstonedTaskError,
              taskqueue.TaskAlreadyExistsError), e:
        logging.warning("Task %r already exists. %s: %s",
                        task_name, e.__class__, e)


class CleanUpJobHandler(base_handler.PostJsonHandler):
  """Command to kick off tasks to clean up a job's data."""

  def handle(self):
    mapreduce_id = self.request.get("mapreduce_id")
    db.delete(model.MapreduceState.get_key_by_job_id(mapreduce_id))
    self.json_response["status"] = ("Job %s successfully cleaned up." %
                                    mapreduce_id)


class AbortJobHandler(base_handler.PostJsonHandler):
  """Command to abort a running job."""

  def handle(self):
    model.MapreduceControl.abort(self.request.get("mapreduce_id"))
    self.json_response["status"] = "Abort signal sent."