From 837d713edffdddce39ede0f86409c1adac92cc73 Mon Sep 17 00:00:00 2001
From: Guanzhong Chen <quantum2048@gmail.com>
Date: Wed, 8 Jan 2020 23:17:05 -0800
Subject: [PATCH] Handle default output device change by capturing from new
 device (#2)

---
 winscap.cpp | 130 ++++++++++++++++++++++++++++++++++++----------------
 1 file changed, 91 insertions(+), 39 deletions(-)

diff --git a/winscap.cpp b/winscap.cpp
index e0cdf4a..4c7286a 100644
--- a/winscap.cpp
+++ b/winscap.cpp
@@ -32,6 +32,44 @@ _COM_SMARTPTR_TYPEDEF(IMMDevice, __uuidof(IMMDevice));
 _COM_SMARTPTR_TYPEDEF(IAudioClient, __uuidof(IAudioClient));
 _COM_SMARTPTR_TYPEDEF(IAudioCaptureClient, __uuidof(IAudioCaptureClient));
 
+class DeviceChangeNotification : public IMMNotificationClient {
+    volatile ULONG ref;
+    volatile bool &changed;
+
+  public:
+    DeviceChangeNotification(volatile bool &changed) : changed(changed) {}
+
+    // This is meant to be allocated on stack, so we don't actually free.
+    STDMETHODIMP_(ULONG) AddRef() { return InterlockedIncrement(&ref); }
+    STDMETHODIMP_(ULONG) Release() { return InterlockedDecrement(&ref); }
+
+    STDMETHODIMP QueryInterface(REFIID iid, void **ppv) {
+        if (iid == IID_IUnknown) {
+            AddRef();
+            *ppv = (IUnknown *)this;
+        } else if (iid == __uuidof(IMMNotificationClient)) {
+            AddRef();
+            *ppv = (IMMNotificationClient *)this;
+        } else {
+            *ppv = nullptr;
+            return E_NOINTERFACE;
+        }
+        return S_OK;
+    }
+
+    STDMETHODIMP OnDefaultDeviceChanged(EDataFlow flow, ERole role, LPCWSTR) {
+        if (flow == eRender && role == eConsole) {
+            changed = true;
+        }
+        return S_OK;
+    }
+
+    STDMETHODIMP OnDeviceAdded(LPCWSTR) { return S_OK; }
+    STDMETHODIMP OnDeviceRemoved(LPCWSTR) { return S_OK; }
+    STDMETHODIMP OnDeviceStateChanged(LPCWSTR, DWORD) { return S_OK; }
+    STDMETHODIMP OnPropertyValueChanged(LPCWSTR, const PROPERTYKEY) { return S_OK; }
+};
+
 class EnsureCaptureStop {
     IAudioClientPtr m_client;
 
@@ -99,15 +137,6 @@ int main(int argc, char *argv[]) {
     CCoInitialize comInit;
     ensure(comInit);
 
-    IMMDeviceEnumeratorPtr pEnumerator;
-    ensure(pEnumerator.CreateInstance(__uuidof(MMDeviceEnumerator)));
-
-    IMMDevicePtr pDevice;
-    ensure(pEnumerator->GetDefaultAudioEndpoint(eRender, eConsole, &pDevice));
-
-    IAudioClientPtr pClient;
-    ensure(pDevice->Activate(__uuidof(IAudioClient), CLSCTX_ALL, nullptr, (void **)&pClient));
-
     WAVEFORMATEX wfx;
     wfx.wFormatTag = WAVE_FORMAT_PCM;
     wfx.nChannels = (WORD)channels;
@@ -116,44 +145,67 @@ int main(int argc, char *argv[]) {
     wfx.nBlockAlign = wfx.nChannels * wfx.wBitsPerSample / 8;
     wfx.nAvgBytesPerSec = wfx.nSamplesPerSec * wfx.nBlockAlign;
 
-    ensure(pClient->Initialize(AUDCLNT_SHAREMODE_SHARED, AUDCLNT_STREAMFLAGS_LOOPBACK,
-                               16 * REFTIMES_PER_MILLISEC, 0, &wfx, nullptr));
+    IMMDeviceEnumeratorPtr pEnumerator;
+    ensure(pEnumerator.CreateInstance(__uuidof(MMDeviceEnumerator)));
 
-    UINT32 bufferFrameCount;
-    ensure(pClient->GetBufferSize(&bufferFrameCount));
-
-    IAudioCaptureClientPtr pCapture;
-    ensure(pClient->GetService(__uuidof(IAudioCaptureClient), (void **)&pCapture));
-
-    DWORD dwDelay = (DWORD)(((double)REFTIMES_PER_SEC * bufferFrameCount / wfx.nSamplesPerSec) /
-                            REFTIMES_PER_MILLISEC / 2);
-
-    LPBYTE pSilence = (LPBYTE)malloc(bufferFrameCount * wfx.nBlockAlign);
-    EnsureFree freeSilence(pSilence);
-    ZeroMemory(pSilence, bufferFrameCount * wfx.nBlockAlign);
-
-    ensure(pClient->Start());
-    EnsureCaptureStop autoStop(pClient);
+    volatile bool deviceChanged = false;
+    DeviceChangeNotification deviceChangeNotification(deviceChanged);
+    pEnumerator->RegisterEndpointNotificationCallback(&deviceChangeNotification);
 
     for (;;) {
-        Sleep(dwDelay);
+        IMMDevicePtr pDevice;
+        ensure(pEnumerator->GetDefaultAudioEndpoint(eRender, eConsole, &pDevice));
 
-        UINT32 packetLength;
-        ensure(pCapture->GetNextPacketSize(&packetLength));
+        IAudioClientPtr pClient;
+        ensure(pDevice->Activate(__uuidof(IAudioClient), CLSCTX_ALL, nullptr, (void **)&pClient));
 
-        while (packetLength) {
-            LPBYTE pData;
-            UINT32 numFramesAvailable;
-            DWORD flags;
+        ensure(pClient->Initialize(AUDCLNT_SHAREMODE_SHARED, AUDCLNT_STREAMFLAGS_LOOPBACK,
+                                   16 * REFTIMES_PER_MILLISEC, 0, &wfx, nullptr));
 
-            ensure(pCapture->GetBuffer(&pData, &numFramesAvailable, &flags, nullptr, nullptr));
+        UINT32 bufferFrameCount;
+        ensure(pClient->GetBufferSize(&bufferFrameCount));
 
-            if (flags & AUDCLNT_BUFFERFLAGS_SILENT)
-                pData = pSilence;
+        IAudioCaptureClientPtr pCapture;
+        ensure(pClient->GetService(__uuidof(IAudioCaptureClient), (void **)&pCapture));
 
-            _write(_fileno(stdout), pData, wfx.nBlockAlign * numFramesAvailable);
-            ensure(pCapture->ReleaseBuffer(numFramesAvailable));
-            ensure(pCapture->GetNextPacketSize(&packetLength));
+        DWORD dwDelay = (DWORD)(((double)REFTIMES_PER_SEC * bufferFrameCount / wfx.nSamplesPerSec) /
+                                REFTIMES_PER_MILLISEC / 2);
+
+        LPBYTE pSilence = (LPBYTE)malloc(bufferFrameCount * wfx.nBlockAlign);
+        EnsureFree freeSilence(pSilence);
+        ZeroMemory(pSilence, bufferFrameCount * wfx.nBlockAlign);
+
+        ensure(pClient->Start());
+        EnsureCaptureStop autoStop(pClient);
+
+        while (!deviceChanged) {
+            Sleep(dwDelay);
+
+            UINT32 packetLength;
+            HRESULT hrNext = pCapture->GetNextPacketSize(&packetLength);
+            if (hrNext == AUDCLNT_E_DEVICE_INVALIDATED) {
+                while (!deviceChanged)
+                    ;
+                break;
+            } else {
+                ensure(hrNext);
+            }
+
+            while (packetLength) {
+                LPBYTE pData;
+                UINT32 numFramesAvailable;
+                DWORD flags;
+
+                ensure(pCapture->GetBuffer(&pData, &numFramesAvailable, &flags, nullptr, nullptr));
+
+                if (flags & AUDCLNT_BUFFERFLAGS_SILENT)
+                    pData = pSilence;
+
+                _write(_fileno(stdout), pData, wfx.nBlockAlign * numFramesAvailable);
+                ensure(pCapture->ReleaseBuffer(numFramesAvailable));
+                ensure(pCapture->GetNextPacketSize(&packetLength));
+            }
         }
+        deviceChanged = false;
     }
 }