Initial commit
[platform/core/ml/aitt.git] / android / modules / webrtc / src / main / java / com / samsung / android / modules / webrtc / WebRTC.java
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package com.samsung.android.modules.webrtc;
17
18 import android.content.Context;
19 import android.os.SystemClock;
20 import android.util.Log;
21
22 import org.json.JSONException;
23 import org.json.JSONObject;
24 import org.webrtc.CapturerObserver;
25 import org.webrtc.DataChannel;
26 import org.webrtc.DefaultVideoDecoderFactory;
27 import org.webrtc.DefaultVideoEncoderFactory;
28 import org.webrtc.EglBase;
29 import org.webrtc.IceCandidate;
30 import org.webrtc.MediaConstraints;
31 import org.webrtc.MediaStream;
32 import org.webrtc.MediaStreamTrack;
33 import org.webrtc.NV21Buffer;
34 import org.webrtc.PeerConnection;
35 import org.webrtc.PeerConnectionFactory;
36 import org.webrtc.RtpReceiver;
37 import org.webrtc.SdpObserver;
38 import org.webrtc.SessionDescription;
39 import org.webrtc.SurfaceTextureHelper;
40 import org.webrtc.VideoCapturer;
41 import org.webrtc.VideoDecoderFactory;
42 import org.webrtc.VideoEncoderFactory;
43 import org.webrtc.VideoFrame;
44 import org.webrtc.VideoSink;
45 import org.webrtc.VideoSource;
46 import org.webrtc.VideoTrack;
47 import static org.webrtc.SessionDescription.Type.ANSWER;
48 import static org.webrtc.SessionDescription.Type.OFFER;
49
50 import java.io.IOException;
51 import java.io.ObjectInputStream;
52 import java.io.ObjectOutputStream;
53 import java.io.Serializable;
54 import java.net.Socket;
55 import java.nio.ByteBuffer;
56 import java.util.ArrayList;
57 import java.util.concurrent.TimeUnit;
58
59 /**
60  * WebRTC class to implement webRTC functionalities
61  */
62 public class WebRTC {
63     private static final String TAG = "WebRTC";
64     public static final String VIDEO_TRACK_ID = "ARDAMSv0";
65     private static final String CANDIDATE = "candidate";
66     private java.net.Socket socket;
67     private boolean isInitiator;
68     private boolean isChannelReady;
69     private boolean isStarted;
70     private boolean isReciever;
71     private PeerConnection peerConnection;
72     private PeerConnectionFactory factory;
73     private VideoTrack videoTrackFromSource;
74     private ObjectOutputStream outStream;
75     private ObjectInputStream inputStream;
76     private SDPThread sdpThread;
77     private Context appContext;
78     private DataChannel localDataChannel;
79     private FrameVideoCapturer videoCapturer;
80     private ReceiveDataCallback dataCallback;
81     private String recieverIP;
82     private Integer recieverPort;
83
84     /**
85      * WebRTC channels supported - Media channel, data channel
86      */
87     public enum DataType{
88         MESSAGE,
89         VIDEOFRAME,
90     }
91
92     /**
93      * WebRTC constructor to create webRTC instance
94      * @param dataType To decide webRTC channel type
95      * @param appContext Application context creating webRTC instance
96      */
97     public WebRTC(DataType dataType , Context appContext) {
98         this.appContext = appContext;
99         this.isReciever = false;
100     }
101
102     /**
103      * WebRTC constructor to create webRTC instance
104      * @param dataType To decide webRTC channel type
105      * @param appContext Application context creating webRTC instance
106      * @param socket Java server socket for webrtc signalling
107      */
108     WebRTC(DataType dataType , Context appContext , Socket socket) {
109         Log.d(TAG , "InWebRTC Constructor");
110         this.appContext = appContext;
111         this.socket = socket;
112         this.isReciever = true;
113     }
114
115     /**
116      * To create data call-back mechanism
117      * @param cb aitt callback registered to receive a webrtc data
118      */
119     public void registerDataCallback(ReceiveDataCallback cb){
120         this.dataCallback = cb;
121     }
122
123     /**
124      * Method to disconnect the connection from peer
125      */
126     public void disconnect() {
127         if (sdpThread != null) {
128             sdpThread.stop();
129         }
130
131         if (socket != null) {
132             new Thread(() -> {
133                 try {
134                     sendMessage(false, "bye");
135                     socket.close();
136                     if (outStream != null) {
137                         outStream.close();
138                     }
139                     if (inputStream != null) {
140                         inputStream.close();
141                     }
142                 } catch (IOException e) {
143                     Log.e(TAG, "Error during disconnect", e);
144                 }
145             }).start();
146         }
147     }
148
149     /**
150      * Method to establish a socket connection with peer node
151      */
152     public void connect() {
153         initialize();
154     }
155
156     /**
157      * Method to establish communication with peer node
158      * @param recieverIP IP Address of the destination(peer) node
159      * @param recieverPort Port number of the destination(peer) node
160      */
161     public void connect(String recieverIP , Integer recieverPort){
162         this.recieverIP = recieverIP;
163         this.recieverPort = recieverPort;
164         initialize();
165     }
166
167     /**
168      * Method to initialize webRTC APIs while establishing connection
169      */
170     private void initialize(){
171         initializePeerConnectionFactory();
172         initializePeerConnections();
173         if(!isReciever){
174             createVideoTrack();
175             addVideoTrack();
176         }
177         isInitiator = isReciever;
178
179         sdpThread = new SDPThread();
180         new Thread(sdpThread).start();
181     }
182
183     /**
184      * Method to create webRTC offer for sdp negotiation
185      */
186     private void doCall() {
187         MediaConstraints sdpMediaConstraints = new MediaConstraints();
188         sdpMediaConstraints.mandatory.add(new MediaConstraints.KeyValuePair("OfferToReceiveVideo", "true"));
189
190         peerConnection.createOffer(new SimpleSdpObserver() {
191             @Override
192             public void onCreateSuccess(SessionDescription sessionDescription) {
193                 Log.d(TAG, "onCreateSuccess: ");
194                 peerConnection.setLocalDescription(new SimpleSdpObserver(), sessionDescription);
195                 JSONObject message = new JSONObject();
196                 try {
197                     message.put("type", "offer");
198                     message.put("sdp", sessionDescription.description);
199                     sendMessage(true , message);
200                 } catch (JSONException | IOException e) {
201                     Log.e(TAG, "Error during create offer", e);
202                 }
203             }
204         }, sdpMediaConstraints);
205     }
206
207     /**
208      * Method to send signalling messages over socket connection
209      * @param isJSON Boolean to check if message is JSON
210      * @param message Data to be sent over webRTC connection
211      * @throws IOException Throws IOException if writing to outStream fails
212      */
213     private void sendMessage(boolean isJSON, Object message) throws IOException {
214         Log.d(TAG, message.toString());
215         if (outStream != null) {
216             if (isJSON) {
217                 outStream.writeObject(new Packet((JSONObject) message));
218             } else {
219                 outStream.writeObject(new Packet((String) message));
220             }
221         }
222     }
223
224     /**
225      * Class to create proxy video sink
226      */
227     private static class ProxyVideoSink implements VideoSink {
228
229         private ReceiveDataCallback dataCallback;
230
231         /**
232          * ProxyVideoSink constructor to create its instance
233          * @param dataCb DataCall back to be set to self-object
234          */
235         ProxyVideoSink(ReceiveDataCallback dataCb){
236             this.dataCallback = dataCb;
237         }
238
239         /**
240          * Method to send data through data call back
241          * @param frame VideoFrame to be transferred using media channel
242          */
243         @Override
244         synchronized public void onFrame(VideoFrame frame) {
245             byte[] rawFrame = createNV21Data(frame.getBuffer().toI420());
246             dataCallback.pushData(rawFrame);
247         }
248
249         /**
250          * Method used to convert VideoFrame to NV21 data format
251          * @param i420Buffer VideoFrame in I420 buffer format
252          * @return the video frame in NV21 data format
253          */
254         public byte[] createNV21Data(VideoFrame.I420Buffer i420Buffer) {
255             final int width = i420Buffer.getWidth();
256             final int height = i420Buffer.getHeight();
257             final int chromaStride = width;
258             final int chromaWidth = (width + 1) / 2;
259             final int chromaHeight = (height + 1) / 2;
260             final int ySize = width * height;
261             final ByteBuffer nv21Buffer = ByteBuffer.allocateDirect(ySize + chromaStride * chromaHeight);
262             final byte[] nv21Data = nv21Buffer.array();
263             for (int y = 0; y < height; ++y) {
264                 for (int x = 0; x < width; ++x) {
265                     final byte yValue = i420Buffer.getDataY().get(y * i420Buffer.getStrideY() + x);
266                     nv21Data[y * width + x] = yValue;
267                 }
268             }
269             for (int y = 0; y < chromaHeight; ++y) {
270                 for (int x = 0; x < chromaWidth; ++x) {
271                     final byte uValue = i420Buffer.getDataU().get(y * i420Buffer.getStrideU() + x);
272                     final byte vValue = i420Buffer.getDataV().get(y * i420Buffer.getStrideV() + x);
273                     nv21Data[ySize + y * chromaStride + 2 * x + 0] = vValue;
274                     nv21Data[ySize + y * chromaStride + 2 * x + 1] = uValue;
275                 }
276             }
277             return nv21Data;
278         }
279     }
280
281     /**
282      * Method to initialize peer connection factory
283      */
284     private void initializePeerConnectionFactory() {
285         EglBase mRootEglBase;
286         mRootEglBase = EglBase.create();
287         VideoEncoderFactory encoderFactory = new DefaultVideoEncoderFactory(mRootEglBase.getEglBaseContext(), true /* enableIntelVp8Encoder */, true);
288         VideoDecoderFactory decoderFactory = new DefaultVideoDecoderFactory(mRootEglBase.getEglBaseContext());
289
290         PeerConnectionFactory.initialize(PeerConnectionFactory.InitializationOptions.builder(appContext).setEnableInternalTracer(true).createInitializationOptions());
291         PeerConnectionFactory.Builder builder = PeerConnectionFactory.builder().setVideoEncoderFactory(encoderFactory).setVideoDecoderFactory(decoderFactory);
292         builder.setOptions(null);
293         factory = builder.createPeerConnectionFactory();
294     }
295
296     /**
297      * Method to create video track
298      */
299     private void createVideoTrack(){
300         videoCapturer = new FrameVideoCapturer();
301         VideoSource videoSource = factory.createVideoSource(false);
302         videoCapturer.initialize(null , null ,videoSource.getCapturerObserver());
303         videoTrackFromSource = factory.createVideoTrack(VIDEO_TRACK_ID, videoSource);
304         videoTrackFromSource.setEnabled(true);
305     }
306
307     /**
308      * Method to initialize peer connections
309      */
310     private void initializePeerConnections() {
311         peerConnection = createPeerConnection(factory);
312         if (peerConnection != null) {
313             localDataChannel = peerConnection.createDataChannel("sendDataChannel", new DataChannel.Init());
314         }
315     }
316
317     /**
318      * Method to add video track
319      */
320     private void addVideoTrack() {
321         MediaStream mediaStream = factory.createLocalMediaStream("ARDAMS");
322         mediaStream.addTrack(videoTrackFromSource);
323         if(peerConnection!=null){
324             peerConnection.addStream(mediaStream);
325         }
326     }
327
328     /**
329      * Method to create peer connection
330      * @param factory Peer connection factory object
331      * @return return factory object
332      */
333     private PeerConnection createPeerConnection(PeerConnectionFactory factory) {
334         PeerConnection.RTCConfiguration rtcConfig = new PeerConnection.RTCConfiguration(new ArrayList<>());
335         MediaConstraints pcConstraints = new MediaConstraints();
336
337         PeerConnection.Observer pcObserver = new PeerConnection.Observer() {
338             @Override
339             public void onSignalingChange(PeerConnection.SignalingState signalingState) {
340                 Log.d(TAG, "onSignalingChange: ");
341             }
342
343             @Override
344             public void onIceConnectionChange(PeerConnection.IceConnectionState iceConnectionState) {
345                 Log.d(TAG, "onIceConnectionChange: ");
346             }
347
348             @Override
349             public void onIceConnectionReceivingChange(boolean b) {
350                 Log.d(TAG, "onIceConnectionReceivingChange: ");
351             }
352
353             @Override
354             public void onIceGatheringChange(PeerConnection.IceGatheringState iceGatheringState) {
355                 Log.d(TAG, "onIceGatheringChange: ");
356             }
357
358             @Override
359             public void onIceCandidate(IceCandidate iceCandidate) {
360                 Log.d(TAG, "onIceCandidate: ");
361                 JSONObject message = new JSONObject();
362                 try {
363                     message.put("type", CANDIDATE);
364                     message.put("label", iceCandidate.sdpMLineIndex);
365                     message.put("id", iceCandidate.sdpMid);
366                     message.put(CANDIDATE, iceCandidate.sdp);
367                     Log.d(TAG, "onIceCandidate: sending candidate " + message);
368                     sendMessage(true , message);
369                 } catch (JSONException | IOException e) {
370                     Log.e(TAG, "Error during onIceCandidate", e);
371                 }
372             }
373
374             @Override
375             public void onIceCandidatesRemoved(IceCandidate[] iceCandidates) {
376                 Log.d(TAG, "onIceCandidatesRemoved: ");
377             }
378
379             @Override
380             public void onAddStream(MediaStream mediaStream) {
381                 Log.d(TAG, "onAddStream: " + mediaStream.videoTracks.size());
382                 VideoTrack remoteVideoTrack = mediaStream.videoTracks.get(0);
383                 remoteVideoTrack.setEnabled(true);
384             }
385
386             @Override
387             public void onRemoveStream(MediaStream mediaStream) {
388                 Log.d(TAG, "onRemoveStream: ");
389             }
390
391             @Override
392             public void onDataChannel(DataChannel dataChannel) {
393                 Log.d(TAG, "onDataChannel: ");
394                 dataChannel.registerObserver(new DataChannel.Observer() {
395                     @Override
396                     public void onBufferedAmountChange(long l) {
397                         //Keep this callback for future usage
398                         Log.d(TAG, "onBufferedAmountChange:");
399                     }
400
401                     @Override
402                     public void onStateChange() {
403                         Log.d(TAG, "onStateChange: remote data channel state: " + dataChannel.state().toString());
404                     }
405
406                     @Override
407                     public void onMessage(DataChannel.Buffer buffer) {
408                         Log.d(TAG, "onMessage: got message");
409                         dataCallback.pushData(readIncomingMessage(buffer.data));
410                     }
411                 });
412             }
413
414             @Override
415             public void onRenegotiationNeeded() {
416                 Log.d(TAG, "onRenegotiationNeeded: ");
417             }
418
419             @Override
420             public void onAddTrack(RtpReceiver rtpReceiver, MediaStream[] mediaStreams) {
421                 MediaStreamTrack track = rtpReceiver.track();
422                 if (track instanceof VideoTrack && isReciever) {
423                     Log.i(TAG, "onAddVideoTrack");
424                     VideoTrack remoteVideoTrack = (VideoTrack) track;
425                     remoteVideoTrack.setEnabled(true);
426                     ProxyVideoSink  videoSink = new ProxyVideoSink(dataCallback);
427                     remoteVideoTrack.addSink(videoSink);
428                 }
429             }
430         };
431         return factory.createPeerConnection(rtcConfig, pcConstraints, pcObserver);
432     }
433
434     /**
435      * Method used to send video data
436      * @param frame Video frame in byte format
437      * @param width width of the video frame
438      * @param height height of the video frame
439      */
440     public void sendVideoData(byte[] frame , int width , int height){
441         videoCapturer.send(frame , width , height);
442     }
443
444     /**
445      * Method to send message data
446      * @param message message to be sent in byte format
447      */
448     public void sendMessageData(byte[] message) {
449         ByteBuffer data = ByteBuffer.wrap(message);
450         localDataChannel.send(new DataChannel.Buffer(data, false));
451     }
452
453     /**
454      * Interface to create data call back mechanism
455      */
456     public interface ReceiveDataCallback{
457         void pushData(byte[] frame);
458     }
459
460     /**
461      * Class packet to create a packet
462      */
463     private static class Packet implements Serializable {
464         boolean isString;
465         String obj;
466         Packet(String s){
467             isString = true;
468             obj = s;
469         }
470
471         Packet(JSONObject json){
472             isString = false;
473             obj = json.toString();
474         }
475     }
476
477     /**
478      * Method to read incoming message and convert it to byte format
479      * @param buffer Message incoming in Byte buffer format
480      * @return returns byteBuffer message in byte format
481      */
482     private byte[] readIncomingMessage(ByteBuffer buffer) {
483         byte[] bytes;
484         if (buffer.hasArray()) {
485             bytes = buffer.array();
486         } else {
487             bytes = new byte[buffer.remaining()];
488             buffer.get(bytes);
489         }
490         return bytes;
491     }
492
493     /**
494      * Class to implement SDP observer
495      */
496     private static class SimpleSdpObserver implements SdpObserver {
497         @Override
498         public void onCreateSuccess(SessionDescription sessionDescription) {
499             //Required for future reference
500         }
501
502         @Override
503         public void onSetSuccess() {
504             Log.d(TAG, "onSetSuccess:");
505         }
506
507         @Override
508         public void onCreateFailure(String s) {
509             Log.d(TAG, "onCreateFailure: Reason = " + s);
510         }
511
512         @Override
513         public void onSetFailure(String s) {
514             Log.d(TAG, "onSetFailure: Reason = " + s);
515         }
516     }
517
518     /**
519      * Class to implement Frame video capturer
520      */
521     private static class FrameVideoCapturer implements VideoCapturer {
522         private CapturerObserver capturerObserver;
523
524         void send(byte[] frame, int width, int height) {
525             long timestampNS = TimeUnit.MILLISECONDS.toNanos(SystemClock.elapsedRealtime());
526             NV21Buffer buffer = new NV21Buffer(frame, width, height, null);
527             VideoFrame videoFrame = new VideoFrame(buffer, 0, timestampNS);
528             this.capturerObserver.onFrameCaptured(videoFrame);
529             videoFrame.release();
530         }
531
532         @Override
533         public void initialize(SurfaceTextureHelper surfaceTextureHelper, Context context, CapturerObserver capturerObserver) {
534             this.capturerObserver = capturerObserver;
535         }
536
537         public void startCapture(int width, int height, int framerate) {
538             //Required for future reference
539         }
540
541         public void stopCapture() throws InterruptedException {
542             //Required for future reference
543         }
544
545         public void changeCaptureFormat(int width, int height, int framerate) {
546             //Required for future reference
547         }
548
549         public void dispose() {
550             //Required for future reference
551         }
552
553         public boolean isScreencast() {
554             return false;
555         }
556     }
557
558     /**
559      * Class to implement SDP thread
560      */
561     private class SDPThread implements Runnable {
562         private volatile boolean isRunning = true;
563
564         @Override
565         public void run() {
566             isChannelReady = true;
567
568             createSocket();
569             invokeSendMessage();
570
571             while (isRunning) {
572                 try {
573                     Packet recvPacketNew = (Packet) inputStream.readObject();
574                     if (recvPacketNew != null) {
575                         if (recvPacketNew.isString) {
576                             String message = recvPacketNew.obj;
577                             checkPacketMessage(message);
578                         } else {
579                             JSONObject message = new JSONObject(recvPacketNew.obj);
580                             Log.d(TAG, "connectToSignallingServer: got message " + message);
581                             decodeMessage(message);
582                         }
583                     }
584                 } catch (ClassNotFoundException | JSONException | IOException e) {
585                     isRunning = false;
586                     Log.e(TAG, "Error during JSON read", e);
587                 }
588             }
589         }
590
591         /**
592          * Method to decode message
593          * @param message Message received in JSON object format
594          */
595         private void decodeMessage(JSONObject message) {
596             try {
597                 if (message.getString("type").equals("offer")) {
598                     Log.d(TAG, "connectToSignallingServer: received an offer " + isInitiator + " " + isStarted);
599                     invokeMaybeStart();
600                     peerConnection.setRemoteDescription(new SimpleSdpObserver(), new SessionDescription(OFFER, message.getString("sdp")));
601                     doAnswer();
602                 } else if (message.getString("type").equals("answer") && isStarted) {
603                     peerConnection.setRemoteDescription(new SimpleSdpObserver(), new SessionDescription(ANSWER, message.getString("sdp")));
604                 } else if (message.getString("type").equals(CANDIDATE) && isStarted) {
605                     Log.d(TAG, "connectToSignallingServer: receiving candidates");
606                     IceCandidate candidate = new IceCandidate(message.getString("id"), message.getInt("label"), message.getString(CANDIDATE));
607                     peerConnection.addIceCandidate(candidate);
608                 }
609             } catch (JSONException e) {
610                 Log.e(TAG, "Error during message decoding", e);
611             }
612         }
613
614         /**
615          * Method to create SDP answer for a given SDP offer
616          */
617         private void doAnswer() {
618             peerConnection.createAnswer(new SimpleSdpObserver() {
619                 @Override
620                 public void onCreateSuccess(SessionDescription sessionDescription) {
621                     peerConnection.setLocalDescription(new SimpleSdpObserver(), sessionDescription);
622                     JSONObject message = new JSONObject();
623                     try {
624                         message.put("type", "answer");
625                         message.put("sdp", sessionDescription.description);
626                         sendMessage(true, message);
627                     } catch (JSONException | IOException e) {
628                         Log.e(TAG, "Error during sdp answer", e);
629                     }
630                 }
631             }, new MediaConstraints());
632         }
633
634         /**
635          * Method used to create a socket for SDP negotiation
636          */
637         private void createSocket(){
638             try {
639                 if(!isReciever){
640                     socket = new Socket(recieverIP, recieverPort);
641                 }
642                 outStream = new ObjectOutputStream(socket.getOutputStream());
643                 inputStream = new ObjectInputStream(socket.getInputStream());
644             } catch (Exception e) {
645                 Log.e(TAG, "Error during create socket", e);
646             }
647         }
648
649         /**
650          * Method to invoke Signalling handshake message
651          */
652         private void invokeSendMessage(){
653             try {
654                 sendMessage(false , "got user media");
655             } catch (Exception e) {
656                 Log.e(TAG, "Error during invoke send message", e);
657             }
658         }
659
660         /**
661          * Method to check if the message in received packet is "got user media"
662          */
663         private void checkPacketMessage(String message){
664             if (message.equals("got user media")) {
665                 maybeStart();
666             }
667         }
668
669         /**
670          * Method to invoke MaybeStart()
671          */
672         private void invokeMaybeStart(){
673             if (!isInitiator && !isStarted) {
674                 maybeStart();
675             }
676         }
677
678         /**
679          * Method to begin SDP negotiation by sending SDP offer to peer
680          */
681         private void maybeStart() {
682             Log.d(TAG, "maybeStart: " + isStarted + " " + isChannelReady);
683             if (!isStarted && isChannelReady) {
684                 isStarted = true;
685                 if (isInitiator) {
686                     doCall();
687                 }
688             }
689         }
690
691         /**
692          * Method to stop thread
693          */
694         public void stop() {
695             isRunning = false;
696         }
697     }
698 }