Fix race condition when stopDiscovery() is called during onServiceFound()/onServiceLost()

This commit is contained in:
Cameron Gutman
2023-07-25 18:46:31 -05:00
parent 67b2853ef0
commit 554fee037c

View File

@@ -20,13 +20,13 @@ import java.util.concurrent.TimeUnit;
@TargetApi(Build.VERSION_CODES.UPSIDE_DOWN_CAKE) @TargetApi(Build.VERSION_CODES.UPSIDE_DOWN_CAKE)
public class NsdManagerDiscoveryAgent extends MdnsDiscoveryAgent { public class NsdManagerDiscoveryAgent extends MdnsDiscoveryAgent {
private static final String SERVICE_TYPE = "_nvstream._tcp"; private static final String SERVICE_TYPE = "_nvstream._tcp";
private NsdManager nsdManager; private final NsdManager nsdManager;
private boolean discoveryActive; private boolean discoveryActive;
private boolean wantsDiscoveryActive; private boolean wantsDiscoveryActive;
private final HashMap<String, NsdManager.ServiceInfoCallback> serviceCallbacks = new HashMap<>(); private final HashMap<String, NsdManager.ServiceInfoCallback> serviceCallbacks = new HashMap<>();
private final ThreadPoolExecutor executor = new ThreadPoolExecutor(1, 1, 0, TimeUnit.SECONDS, new LinkedBlockingQueue<>()); private final ThreadPoolExecutor executor = new ThreadPoolExecutor(1, 1, 0, TimeUnit.SECONDS, new LinkedBlockingQueue<>());
private NsdManager.DiscoveryListener discoveryListener = new NsdManager.DiscoveryListener() { private final NsdManager.DiscoveryListener discoveryListener = new NsdManager.DiscoveryListener() {
@Override @Override
public void onStartDiscoveryFailed(String serviceType, int errorCode) { public void onStartDiscoveryFailed(String serviceType, int errorCode) {
discoveryActive = false; discoveryActive = false;
@@ -63,41 +63,57 @@ public class NsdManagerDiscoveryAgent extends MdnsDiscoveryAgent {
@Override @Override
public void onServiceFound(NsdServiceInfo nsdServiceInfo) { public void onServiceFound(NsdServiceInfo nsdServiceInfo) {
LimeLog.info("NSD: Machine appeared: "+nsdServiceInfo.getServiceName()); // Protect against racing stopDiscovery() call
synchronized (serviceCallbacks) {
NsdManager.ServiceInfoCallback serviceInfoCallback = new NsdManager.ServiceInfoCallback() { // Bail if we've been stopped
@Override if (!wantsDiscoveryActive) {
public void onServiceInfoCallbackRegistrationFailed(int errorCode) { return;
LimeLog.severe("NSD: Service info callback registration failed: " + errorCode);
listener.notifyDiscoveryFailure(new RuntimeException("onServiceInfoCallbackRegistrationFailed(): " + errorCode));
} }
@Override LimeLog.info("NSD: Machine appeared: "+nsdServiceInfo.getServiceName());
public void onServiceUpdated(NsdServiceInfo nsdServiceInfo) {
LimeLog.info("NSD: Machine resolved: "+nsdServiceInfo.getServiceName());
reportNewComputer(nsdServiceInfo.getServiceName(), nsdServiceInfo.getPort(),
getV4Addrs(nsdServiceInfo.getHostAddresses()),
getV6Addrs(nsdServiceInfo.getHostAddresses()));
}
@Override NsdManager.ServiceInfoCallback serviceInfoCallback = new NsdManager.ServiceInfoCallback() {
public void onServiceLost() {} @Override
public void onServiceInfoCallbackRegistrationFailed(int errorCode) {
LimeLog.severe("NSD: Service info callback registration failed: " + errorCode);
listener.notifyDiscoveryFailure(new RuntimeException("onServiceInfoCallbackRegistrationFailed(): " + errorCode));
}
@Override @Override
public void onServiceInfoCallbackUnregistered() {} public void onServiceUpdated(NsdServiceInfo nsdServiceInfo) {
}; LimeLog.info("NSD: Machine resolved: "+nsdServiceInfo.getServiceName());
reportNewComputer(nsdServiceInfo.getServiceName(), nsdServiceInfo.getPort(),
getV4Addrs(nsdServiceInfo.getHostAddresses()),
getV6Addrs(nsdServiceInfo.getHostAddresses()));
}
nsdManager.registerServiceInfoCallback(nsdServiceInfo, executor, serviceInfoCallback); @Override
serviceCallbacks.put(nsdServiceInfo.getServiceName(), serviceInfoCallback); public void onServiceLost() {}
@Override
public void onServiceInfoCallbackUnregistered() {}
};
nsdManager.registerServiceInfoCallback(nsdServiceInfo, executor, serviceInfoCallback);
serviceCallbacks.put(nsdServiceInfo.getServiceName(), serviceInfoCallback);
}
} }
@Override @Override
public void onServiceLost(NsdServiceInfo nsdServiceInfo) { public void onServiceLost(NsdServiceInfo nsdServiceInfo) {
LimeLog.info("NSD: Machine lost: "+nsdServiceInfo.getServiceName()); // Protect against racing stopDiscovery() call
synchronized (serviceCallbacks) {
// Bail if we've been stopped
if (!wantsDiscoveryActive) {
return;
}
NsdManager.ServiceInfoCallback serviceInfoCallback = serviceCallbacks.remove(nsdServiceInfo.getServiceName()); LimeLog.info("NSD: Machine lost: " + nsdServiceInfo.getServiceName());
if (serviceInfoCallback != null) {
nsdManager.unregisterServiceInfoCallback(serviceInfoCallback); NsdManager.ServiceInfoCallback serviceInfoCallback = serviceCallbacks.remove(nsdServiceInfo.getServiceName());
if (serviceInfoCallback != null) {
nsdManager.unregisterServiceInfoCallback(serviceInfoCallback);
}
} }
} }
}; };
@@ -119,18 +135,21 @@ public class NsdManagerDiscoveryAgent extends MdnsDiscoveryAgent {
@Override @Override
public void stopDiscovery() { public void stopDiscovery() {
wantsDiscoveryActive = false; // Protect against racing ServiceInfoCallback and DiscoveryListener callbacks
synchronized (serviceCallbacks) {
wantsDiscoveryActive = false;
// Unregister the service discovery listener // Unregister the service discovery listener
if (discoveryActive) { if (discoveryActive) {
nsdManager.stopServiceDiscovery(discoveryListener); nsdManager.stopServiceDiscovery(discoveryListener);
} }
// Unregister all service info callbacks // Unregister all service info callbacks
for (NsdManager.ServiceInfoCallback callback : serviceCallbacks.values()) { for (NsdManager.ServiceInfoCallback callback : serviceCallbacks.values()) {
nsdManager.unregisterServiceInfoCallback(callback); nsdManager.unregisterServiceInfoCallback(callback);
}
serviceCallbacks.clear();
} }
serviceCallbacks.clear();
} }
private static Inet4Address[] getV4Addrs(List<InetAddress> addrs) { private static Inet4Address[] getV4Addrs(List<InetAddress> addrs) {