Imported Upstream version 1.72.0
[platform/upstream/boost.git] / boost / mpi / detail / request_handlers.hpp
1 // Copyright (C) 2018 Alain Miniussi <alain.miniussi@oca.eu>.
2
3 // Use, modification and distribution is subject to the Boost Software
4 // License, Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at
5 // http://www.boost.org/LICENSE_1_0.txt)
6
7 // Request implementation dtails
8
9 // This header should be included only after the communicator and request 
10 // classes has been defined.
11 #ifndef BOOST_MPI_REQUEST_HANDLERS_HPP
12 #define BOOST_MPI_REQUEST_HANDLERS_HPP
13
14 #include <boost/mpi/skeleton_and_content_types.hpp>
15
16 namespace boost { namespace mpi {
17
18 namespace detail {
19 /**
20  * Internal data structure that stores everything required to manage
21  * the receipt of serialized data via a request object.
22  */
23 template<typename T>
24 struct serialized_irecv_data {
25   serialized_irecv_data(const communicator& comm, T& value)
26     : m_ia(comm), m_value(value) {}
27
28   void deserialize(status& stat) 
29   { 
30     m_ia >> m_value; 
31     stat.m_count = 1;
32   }
33
34   std::size_t     m_count;
35   packed_iarchive m_ia;
36   T&              m_value;
37 };
38
39 template<>
40 struct serialized_irecv_data<packed_iarchive>
41 {
42   serialized_irecv_data(communicator const&, packed_iarchive& ia) : m_ia(ia) { }
43
44   void deserialize(status&) { /* Do nothing. */ }
45
46   std::size_t      m_count;
47   packed_iarchive& m_ia;
48 };
49
50 /**
51  * Internal data structure that stores everything required to manage
52  * the receipt of an array of serialized data via a request object.
53  */
54 template<typename T>
55 struct serialized_array_irecv_data
56 {
57   serialized_array_irecv_data(const communicator& comm, T* values, int n)
58     : m_count(0), m_ia(comm), m_values(values), m_nb(n) {}
59
60   void deserialize(status& stat);
61
62   std::size_t     m_count;
63   packed_iarchive m_ia;
64   T*              m_values;
65   int             m_nb;
66 };
67
68 template<typename T>
69 void serialized_array_irecv_data<T>::deserialize(status& stat)
70 {
71   T* v = m_values;
72   T* end =  m_values+m_nb;
73   while (v < end) {
74     m_ia >> *v++;
75   }
76   stat.m_count = m_nb;
77 }
78
79 /**
80  * Internal data structure that stores everything required to manage
81  * the receipt of an array of primitive data but unknown size.
82  * Such an array can have been send with blocking operation and so must
83  * be compatible with the (size_t,raw_data[]) format.
84  */
85 template<typename T, class A>
86 struct dynamic_array_irecv_data
87 {
88   BOOST_STATIC_ASSERT_MSG(is_mpi_datatype<T>::value, "Can only be specialized for MPI datatypes.");
89
90   dynamic_array_irecv_data(std::vector<T,A>& values)
91     : m_count(-1), m_values(values) {}
92
93   std::size_t       m_count;
94   std::vector<T,A>& m_values;
95 };
96
97 template<typename T>
98 struct serialized_irecv_data<const skeleton_proxy<T> >
99 {
100   serialized_irecv_data(const communicator& comm, skeleton_proxy<T> proxy)
101     : m_isa(comm), m_ia(m_isa.get_skeleton()), m_proxy(proxy) { }
102
103   void deserialize(status& stat) 
104   { 
105     m_isa >> m_proxy.object;
106     stat.m_count = 1;
107   }
108
109   std::size_t              m_count;
110   packed_skeleton_iarchive m_isa;
111   packed_iarchive&         m_ia;
112   skeleton_proxy<T>        m_proxy;
113 };
114
115 template<typename T>
116 struct serialized_irecv_data<skeleton_proxy<T> >
117   : public serialized_irecv_data<const skeleton_proxy<T> >
118 {
119   typedef serialized_irecv_data<const skeleton_proxy<T> > inherited;
120
121   serialized_irecv_data(const communicator& comm, const skeleton_proxy<T>& proxy)
122     : inherited(comm, proxy) { }
123 };
124 }
125
126 #if BOOST_MPI_VERSION >= 3
127 template<class Data>
128 class request::probe_handler
129   : public request::handler,
130     protected Data {
131
132 protected:
133   template<typename I1>
134   probe_handler(communicator const& comm, int source, int tag, I1& i1)
135     : Data(comm, i1),
136       m_comm(comm),
137       m_source(source),
138       m_tag(tag) {}
139   // no variadic template for now
140   template<typename I1, typename I2>
141   probe_handler(communicator const& comm, int source, int tag, I1& i1, I2& i2)
142     : Data(comm, i1, i2),
143       m_comm(comm),
144       m_source(source),
145       m_tag(tag) {}
146
147 public:
148   bool active() const { return m_source != MPI_PROC_NULL; }
149   optional<MPI_Request&> trivial() { return boost::none; }
150   void cancel() { m_source = MPI_PROC_NULL; }
151
152   status wait() {
153     MPI_Message msg;
154     status stat;
155     BOOST_MPI_CHECK_RESULT(MPI_Mprobe, (m_source,m_tag,m_comm,&msg,&stat.m_status));
156     return unpack(msg, stat);
157   }
158   
159   optional<status> test() {
160     status stat;
161     int flag = 0;
162     MPI_Message msg;
163     BOOST_MPI_CHECK_RESULT(MPI_Improbe, (m_source,m_tag,m_comm,&flag,&msg,&stat.m_status));
164     if (flag) {
165       return unpack(msg, stat);
166     } else {
167       return optional<status>();
168     } 
169   }
170
171 protected:
172   friend class request;
173
174   status unpack(MPI_Message& msg, status& stat) {
175     int count;
176     MPI_Datatype datatype = this->Data::datatype();
177     BOOST_MPI_CHECK_RESULT(MPI_Get_count, (&stat.m_status, datatype, &count));
178     this->Data::resize(count);
179     BOOST_MPI_CHECK_RESULT(MPI_Mrecv, (this->Data::buffer(), count, datatype, &msg, &stat.m_status));
180     this->Data::deserialize();
181     m_source = MPI_PROC_NULL;
182     stat.m_count = 1;
183     return stat;
184   }
185   
186   communicator const& m_comm;
187   int m_source;
188   int m_tag;
189 };
190 #endif // BOOST_MPI_VERSION >= 3
191
192 namespace detail {
193 template<class A>
194 struct dynamic_primitive_array_data {
195   dynamic_primitive_array_data(communicator const&, A& arr) : m_buffer(arr) {}
196   
197   void* buffer() { return m_buffer.data(); }
198   void  resize(std::size_t sz) { m_buffer.resize(sz); }
199   void  deserialize() {}
200   MPI_Datatype datatype() { return get_mpi_datatype<typename A::value_type>(); }
201   
202   A& m_buffer;
203 };
204
205 template<typename T>
206 struct serialized_data {
207   serialized_data(communicator const& comm, T& value) : m_archive(comm), m_value(value) {}
208
209   void* buffer() { return m_archive.address(); }
210   void  resize(std::size_t sz) { m_archive.resize(sz); }
211   void  deserialize() { m_archive >> m_value; }
212   MPI_Datatype datatype() { return MPI_PACKED; }
213
214   packed_iarchive m_archive;
215   T& m_value;
216 };
217
218 template<>
219 struct serialized_data<packed_iarchive> {
220   serialized_data(communicator const& comm, packed_iarchive& ar) : m_archive(ar) {}
221   
222   void* buffer() { return m_archive.address(); }
223   void  resize(std::size_t sz) { m_archive.resize(sz); }
224   void  deserialize() {}
225   MPI_Datatype datatype() { return MPI_PACKED; }
226
227   packed_iarchive& m_archive;
228 };
229
230 template<typename T>
231 struct serialized_data<const skeleton_proxy<T> > {
232   serialized_data(communicator const& comm, skeleton_proxy<T> skel)
233     : m_proxy(skel),
234       m_archive(comm) {}
235   
236   void* buffer() { return m_archive.get_skeleton().address(); }
237   void  resize(std::size_t sz) { m_archive.get_skeleton().resize(sz); }
238   void  deserialize() { m_archive >> m_proxy.object; }
239   MPI_Datatype datatype() { return MPI_PACKED; }
240
241   skeleton_proxy<T> m_proxy;
242   packed_skeleton_iarchive m_archive;
243 };
244
245 template<typename T>
246 struct serialized_data<skeleton_proxy<T> >
247   : public serialized_data<const skeleton_proxy<T> > {
248   typedef serialized_data<const skeleton_proxy<T> > super;
249   serialized_data(communicator const& comm, skeleton_proxy<T> skel)
250     : super(comm, skel) {}
251 };
252
253 template<typename T>
254 struct serialized_array_data {
255   serialized_array_data(communicator const& comm, T* values, int nb)
256     : m_archive(comm), m_values(values), m_nb(nb) {}
257
258   void* buffer() { return m_archive.address(); }
259   void  resize(std::size_t sz) { m_archive.resize(sz); }
260   void  deserialize() {
261     T* end = m_values + m_nb;
262     T* v = m_values;
263     while (v != end) {
264       m_archive >> *v++;
265     }
266   }
267   MPI_Datatype datatype() { return MPI_PACKED; }
268
269   packed_iarchive m_archive;
270   T*  m_values;
271   int m_nb;
272 };
273
274 }
275
276 class BOOST_MPI_DECL request::legacy_handler : public request::handler {
277 public:
278   legacy_handler(communicator const& comm, int source, int tag);
279   
280   void cancel() {
281     for (int i = 0; i < 2; ++i) {
282       if (m_requests[i] != MPI_REQUEST_NULL) {
283         BOOST_MPI_CHECK_RESULT(MPI_Cancel, (m_requests+i));
284       }
285     }
286   }
287   
288   bool active() const;
289   optional<MPI_Request&> trivial();
290   
291   MPI_Request      m_requests[2];
292   communicator     m_comm;
293   int              m_source;
294   int              m_tag;
295 };
296
297 template<typename T>
298 class request::legacy_serialized_handler 
299   : public request::legacy_handler, 
300     protected detail::serialized_irecv_data<T> {
301 public:
302   typedef detail::serialized_irecv_data<T> extra;
303   legacy_serialized_handler(communicator const& comm, int source, int tag, T& value)
304     : legacy_handler(comm, source, tag),
305       extra(comm, value)  {
306     BOOST_MPI_CHECK_RESULT(MPI_Irecv,
307                            (&this->extra::m_count, 1, 
308                             get_mpi_datatype(this->extra::m_count),
309                             source, tag, comm, m_requests+0));
310     
311   }
312
313   status wait() {
314     status stat;
315     if (m_requests[1] == MPI_REQUEST_NULL) {
316       // Wait for the count message to complete
317       BOOST_MPI_CHECK_RESULT(MPI_Wait,
318                              (m_requests, &stat.m_status));
319       // Resize our buffer and get ready to receive its data
320       this->extra::m_ia.resize(this->extra::m_count);
321       BOOST_MPI_CHECK_RESULT(MPI_Irecv,
322                              (this->extra::m_ia.address(), this->extra::m_ia.size(), MPI_PACKED,
323                               stat.source(), stat.tag(), 
324                               MPI_Comm(m_comm), m_requests + 1));
325     }
326
327     // Wait until we have received the entire message
328     BOOST_MPI_CHECK_RESULT(MPI_Wait,
329                            (m_requests + 1, &stat.m_status));
330
331     this->deserialize(stat);
332     return stat;    
333   }
334   
335   optional<status> test() {
336     status stat;
337     int flag = 0;
338     
339     if (m_requests[1] == MPI_REQUEST_NULL) {
340       // Check if the count message has completed
341       BOOST_MPI_CHECK_RESULT(MPI_Test,
342                              (m_requests, &flag, &stat.m_status));
343       if (flag) {
344         // Resize our buffer and get ready to receive its data
345         this->extra::m_ia.resize(this->extra::m_count);
346         BOOST_MPI_CHECK_RESULT(MPI_Irecv,
347                                (this->extra::m_ia.address(), this->extra::m_ia.size(),MPI_PACKED,
348                                 stat.source(), stat.tag(), 
349                                 MPI_Comm(m_comm), m_requests + 1));
350       } else
351         return optional<status>(); // We have not finished yet
352     } 
353
354     // Check if we have received the message data
355     BOOST_MPI_CHECK_RESULT(MPI_Test,
356                            (m_requests + 1, &flag, &stat.m_status));
357     if (flag) {
358       this->deserialize(stat);
359       return stat;
360     } else 
361       return optional<status>();
362   }
363 };
364
365 template<typename T>
366 class request::legacy_serialized_array_handler 
367   : public    request::legacy_handler,
368     protected detail::serialized_array_irecv_data<T> {
369   typedef detail::serialized_array_irecv_data<T> extra;
370
371 public:
372   legacy_serialized_array_handler(communicator const& comm, int source, int tag, T* values, int n)
373     : legacy_handler(comm, source, tag),
374       extra(comm, values, n) {
375     BOOST_MPI_CHECK_RESULT(MPI_Irecv,
376                            (&this->extra::m_count, 1, 
377                             get_mpi_datatype(this->extra::m_count),
378                             source, tag, comm, m_requests+0));
379   }
380
381   status wait() {
382     status stat;
383     if (m_requests[1] == MPI_REQUEST_NULL) {
384       // Wait for the count message to complete
385       BOOST_MPI_CHECK_RESULT(MPI_Wait,
386                              (m_requests, &stat.m_status));
387       // Resize our buffer and get ready to receive its data
388       this->extra::m_ia.resize(this->extra::m_count);
389       BOOST_MPI_CHECK_RESULT(MPI_Irecv,
390                              (this->extra::m_ia.address(), this->extra::m_ia.size(), MPI_PACKED,
391                               stat.source(), stat.tag(), 
392                               MPI_Comm(m_comm), m_requests + 1));
393     }
394
395     // Wait until we have received the entire message
396     BOOST_MPI_CHECK_RESULT(MPI_Wait,
397                            (m_requests + 1, &stat.m_status));
398
399     this->deserialize(stat);
400     return stat;
401   }
402   
403   optional<status> test() {
404     status stat;
405     int flag = 0;
406     
407     if (m_requests[1] == MPI_REQUEST_NULL) {
408       // Check if the count message has completed
409       BOOST_MPI_CHECK_RESULT(MPI_Test,
410                              (m_requests, &flag, &stat.m_status));
411       if (flag) {
412         // Resize our buffer and get ready to receive its data
413         this->extra::m_ia.resize(this->extra::m_count);
414         BOOST_MPI_CHECK_RESULT(MPI_Irecv,
415                                (this->extra::m_ia.address(), this->extra::m_ia.size(),MPI_PACKED,
416                                 stat.source(), stat.tag(), 
417                                 MPI_Comm(m_comm), m_requests + 1));
418       } else
419         return optional<status>(); // We have not finished yet
420     } 
421
422     // Check if we have received the message data
423     BOOST_MPI_CHECK_RESULT(MPI_Test,
424                            (m_requests + 1, &flag, &stat.m_status));
425     if (flag) {
426       this->deserialize(stat);
427       return stat;
428     } else 
429       return optional<status>();
430   }
431 };
432
433 template<typename T, class A>
434 class request::legacy_dynamic_primitive_array_handler 
435   : public request::legacy_handler,
436     protected detail::dynamic_array_irecv_data<T,A>
437 {
438   typedef detail::dynamic_array_irecv_data<T,A> extra;
439
440 public:
441   legacy_dynamic_primitive_array_handler(communicator const& comm, int source, int tag, std::vector<T,A>& values)
442     : legacy_handler(comm, source, tag),
443       extra(values) {
444     BOOST_MPI_CHECK_RESULT(MPI_Irecv,
445                            (&this->extra::m_count, 1, 
446                             get_mpi_datatype(this->extra::m_count),
447                             source, tag, comm, m_requests+0));
448   }
449
450   status wait() {
451     status stat;
452     if (m_requests[1] == MPI_REQUEST_NULL) {
453       // Wait for the count message to complete
454       BOOST_MPI_CHECK_RESULT(MPI_Wait,
455                              (m_requests, &stat.m_status));
456       // Resize our buffer and get ready to receive its data
457       this->extra::m_values.resize(this->extra::m_count);
458       BOOST_MPI_CHECK_RESULT(MPI_Irecv,
459                              (&(this->extra::m_values[0]), this->extra::m_values.size(), get_mpi_datatype<T>(),
460                               stat.source(), stat.tag(), 
461                               MPI_Comm(m_comm), m_requests + 1));
462     }
463     // Wait until we have received the entire message
464     BOOST_MPI_CHECK_RESULT(MPI_Wait,
465                            (m_requests + 1, &stat.m_status));
466     return stat;    
467   }
468
469   optional<status> test() {
470     status stat;
471     int flag = 0;
472     
473     if (m_requests[1] == MPI_REQUEST_NULL) {
474       // Check if the count message has completed
475       BOOST_MPI_CHECK_RESULT(MPI_Test,
476                              (m_requests, &flag, &stat.m_status));
477       if (flag) {
478         // Resize our buffer and get ready to receive its data
479         this->extra::m_values.resize(this->extra::m_count);
480         BOOST_MPI_CHECK_RESULT(MPI_Irecv,
481                                (&(this->extra::m_values[0]), this->extra::m_values.size(), get_mpi_datatype<T>(),
482                                 stat.source(), stat.tag(), 
483                                 MPI_Comm(m_comm), m_requests + 1));
484       } else
485         return optional<status>(); // We have not finished yet
486     } 
487
488     // Check if we have received the message data
489     BOOST_MPI_CHECK_RESULT(MPI_Test,
490                            (m_requests + 1, &flag, &stat.m_status));
491     if (flag) {
492       return stat;
493     } else 
494       return optional<status>();
495   }
496 };
497
498 class BOOST_MPI_DECL request::trivial_handler : public request::handler {
499
500 public:
501   trivial_handler();
502   
503   status wait();
504   optional<status> test();
505   void cancel();
506   
507   bool active() const;
508   optional<MPI_Request&> trivial();
509
510 private:
511   friend class request;
512   MPI_Request      m_request;
513 };
514
515 class request::dynamic_handler : public request::handler {
516   dynamic_handler();
517   
518   status wait();
519   optional<status> test();
520   void cancel();
521   
522   bool active() const;
523   optional<MPI_Request&> trivial();
524
525 private:
526   friend class request;
527   MPI_Request      m_requests[2];
528 };
529
530 template<typename T> 
531 request request::make_serialized(communicator const& comm, int source, int tag, T& value) {
532 #if defined(BOOST_MPI_USE_IMPROBE)
533   return request(new probe_handler<detail::serialized_data<T> >(comm, source, tag, value));
534 #else
535   return request(new legacy_serialized_handler<T>(comm, source, tag, value));
536 #endif
537 }
538
539 template<typename T>
540 request request::make_serialized_array(communicator const& comm, int source, int tag, T* values, int n) {
541 #if defined(BOOST_MPI_USE_IMPROBE)
542   return request(new probe_handler<detail::serialized_array_data<T> >(comm, source, tag, values, n));
543 #else
544   return request(new legacy_serialized_array_handler<T>(comm, source, tag, values, n));
545 #endif
546 }
547
548 template<typename T, class A>
549 request request::make_dynamic_primitive_array_recv(communicator const& comm, int source, int tag, 
550                                                    std::vector<T,A>& values) {
551 #if defined(BOOST_MPI_USE_IMPROBE)
552   return request(new probe_handler<detail::dynamic_primitive_array_data<std::vector<T,A> > >(comm,source,tag,values));
553 #else
554   return request(new legacy_dynamic_primitive_array_handler<T,A>(comm, source, tag, values));
555 #endif
556 }
557
558 template<typename T>
559 request
560 request::make_trivial_send(communicator const& comm, int dest, int tag, T const* values, int n) {
561   trivial_handler* handler = new trivial_handler;
562   BOOST_MPI_CHECK_RESULT(MPI_Isend,
563                          (const_cast<T*>(values), n, 
564                           get_mpi_datatype<T>(),
565                           dest, tag, comm, &handler->m_request));
566   return request(handler);
567 }
568
569 template<typename T>
570 request
571 request::make_trivial_send(communicator const& comm, int dest, int tag, T const& value) {
572   return make_trivial_send(comm, dest, tag, &value, 1);
573 }
574
575 template<typename T>
576 request
577 request::make_trivial_recv(communicator const& comm, int dest, int tag, T* values, int n) {
578   trivial_handler* handler = new trivial_handler;
579   BOOST_MPI_CHECK_RESULT(MPI_Irecv,
580                          (values, n, 
581                           get_mpi_datatype<T>(),
582                           dest, tag, comm, &handler->m_request));
583   return request(handler);
584 }
585
586 template<typename T>
587 request
588 request::make_trivial_recv(communicator const& comm, int dest, int tag, T& value) {
589   return make_trivial_recv(comm, dest, tag, &value, 1);
590 }
591
592 template<typename T, class A>
593 request request::make_dynamic_primitive_array_send(communicator const& comm, int dest, int tag, 
594                                                    std::vector<T,A> const& values) {
595 #if defined(BOOST_MPI_USE_IMPROBE)
596   return make_trivial_send(comm, dest, tag, values.data(), values.size());
597 #else
598   {
599     // non blocking recv by legacy_dynamic_primitive_array_handler
600     // blocking recv by status recv_vector(source,tag,value,primitive)
601     boost::shared_ptr<std::size_t> size(new std::size_t(values.size()));
602     dynamic_handler* handler = new dynamic_handler;
603     request req(handler);
604     req.preserve(size);
605     
606     BOOST_MPI_CHECK_RESULT(MPI_Isend,
607                            (size.get(), 1,
608                             get_mpi_datatype(*size),
609                             dest, tag, comm, handler->m_requests+0));
610     BOOST_MPI_CHECK_RESULT(MPI_Isend,
611                            (const_cast<T*>(values.data()), *size, 
612                             get_mpi_datatype<T>(),
613                             dest, tag, comm, handler->m_requests+1));
614     return req;
615   }
616 #endif
617 }
618
619 inline
620 request::legacy_handler::legacy_handler(communicator const& comm, int source, int tag)
621   : m_comm(comm),
622     m_source(source),
623     m_tag(tag)
624 {
625   m_requests[0] = MPI_REQUEST_NULL;
626   m_requests[1] = MPI_REQUEST_NULL;
627 }
628     
629 }}
630
631 #endif // BOOST_MPI_REQUEST_HANDLERS_HPP