2828import com .google .cloud .spanner .MockSpannerServiceImpl .SimulatedExecutionTime ;
2929import com .google .cloud .spanner .connection .AbstractMockServerTest ;
3030import com .google .common .collect .ImmutableSet ;
31+ import com .google .cloud .spanner .XGoogSpannerRequestId ;
3132import com .google .spanner .v1 .BatchCreateSessionsRequest ;
3233import com .google .spanner .v1 .BeginTransactionRequest ;
3334import com .google .spanner .v1 .ExecuteSqlRequest ;
6465@ RunWith (JUnit4 .class )
6566public class RetryOnDifferentGrpcChannelMockServerTest extends AbstractMockServerTest {
6667 private static final Map <String , Set <InetSocketAddress >> SERVER_ADDRESSES = new HashMap <>();
68+ private static final Map <String , Set <Long >> CHANNEL_HINTS = new HashMap <>();
6769
6870 @ BeforeClass
6971 public static void startStaticServer () throws IOException {
@@ -79,6 +81,7 @@ public static void removeSystemProperty() {
7981 @ After
8082 public void clearRequests () {
8183 SERVER_ADDRESSES .clear ();
84+ CHANNEL_HINTS .clear ();
8285 mockSpanner .clearRequests ();
8386 mockSpanner .removeAllExecutionTimes ();
8487 }
@@ -91,6 +94,7 @@ public <ReqT, RespT> Listener<ReqT> interceptCall(
9194 Metadata metadata ,
9295 ServerCallHandler <ReqT , RespT > serverCallHandler ) {
9396 Attributes attributes = serverCall .getAttributes ();
97+ String methodName = serverCall .getMethodDescriptor ().getFullMethodName ();
9498 //noinspection unchecked,deprecation
9599 Attributes .Key <InetSocketAddress > key =
96100 (Attributes .Key <InetSocketAddress >)
@@ -103,10 +107,27 @@ public <ReqT, RespT> Listener<ReqT> interceptCall(
103107 synchronized (SERVER_ADDRESSES ) {
104108 Set <InetSocketAddress > addresses =
105109 SERVER_ADDRESSES .getOrDefault (
106- serverCall . getMethodDescriptor (). getFullMethodName () , new HashSet <>());
110+ methodName , new HashSet <>());
107111 addresses .add (address );
108112 SERVER_ADDRESSES .putIfAbsent (
109- serverCall .getMethodDescriptor ().getFullMethodName (), addresses );
113+ methodName , addresses );
114+ }
115+ }
116+ String requestId = metadata .get (XGoogSpannerRequestId .REQUEST_HEADER_KEY );
117+ if (requestId != null ) {
118+ // REQUEST_ID format: version.randProcessId.nthClientId.nthChannelId.nthRequest.attempt
119+ String [] parts = requestId .split ("\\ ." );
120+ if (parts .length >= 6 ) {
121+ try {
122+ long channelHint = Long .parseLong (parts [3 ]);
123+ synchronized (CHANNEL_HINTS ) {
124+ Set <Long > hints = CHANNEL_HINTS .getOrDefault (methodName , new HashSet <>());
125+ hints .add (channelHint );
126+ CHANNEL_HINTS .putIfAbsent (methodName , hints );
127+ }
128+ } catch (NumberFormatException ignore ) {
129+ // Ignore malformed header values in tests.
130+ }
110131 }
111132 }
112133 return serverCallHandler .startCall (serverCall , metadata );
@@ -157,8 +178,8 @@ public void testReadWriteTransaction_retriesOnNewChannel() {
157178 assertNotEquals (requests .get (0 ).getSession (), requests .get (1 ).getSession ());
158179 assertEquals (
159180 2 ,
160- SERVER_ADDRESSES
161- .getOrDefault ("google.spanner.v1.Spanner/BeginTransaction" , ImmutableSet . of ())
181+ CHANNEL_HINTS
182+ .getOrDefault ("google.spanner.v1.Spanner/BeginTransaction" , new HashSet <> ())
162183 .size ());
163184 }
164185
@@ -201,8 +222,8 @@ public void testReadWriteTransaction_stopsRetrying() {
201222 assertEquals (numChannels , sessions .size ());
202223 assertEquals (
203224 numChannels ,
204- SERVER_ADDRESSES
205- .getOrDefault ("google.spanner.v1.Spanner/BeginTransaction" , ImmutableSet . of ())
225+ CHANNEL_HINTS
226+ .getOrDefault ("google.spanner.v1.Spanner/BeginTransaction" , new HashSet <> ())
206227 .size ());
207228 }
208229 }
@@ -275,8 +296,8 @@ public void testDenyListedChannelIsCleared() {
275296 assertEquals (numChannels + 1 , sessions .size ());
276297 assertEquals (
277298 numChannels ,
278- SERVER_ADDRESSES
279- .getOrDefault ("google.spanner.v1.Spanner/BeginTransaction" , ImmutableSet . of ())
299+ CHANNEL_HINTS
300+ .getOrDefault ("google.spanner.v1.Spanner/BeginTransaction" , new HashSet <> ())
280301 .size ());
281302 assertEquals (numChannels , mockSpanner .countRequestsOfType (BatchCreateSessionsRequest .class ));
282303 }
@@ -303,11 +324,11 @@ public void testSingleUseQuery_retriesOnNewChannel() {
303324 List <ExecuteSqlRequest > requests = mockSpanner .getRequestsOfType (ExecuteSqlRequest .class );
304325 // The requests use the same multiplexed session.
305326 assertEquals (requests .get (0 ).getSession (), requests .get (1 ).getSession ());
306- // The requests use two different gRPC channels .
327+ // The requests use two different channel hints (which may map to same physical channel) .
307328 assertEquals (
308329 2 ,
309- SERVER_ADDRESSES
310- .getOrDefault ("google.spanner.v1.Spanner/ExecuteStreamingSql" , ImmutableSet . of ())
330+ CHANNEL_HINTS
331+ .getOrDefault ("google.spanner.v1.Spanner/ExecuteStreamingSql" , new HashSet <> ())
311332 .size ());
312333 }
313334
@@ -327,19 +348,19 @@ public void testSingleUseQuery_stopsRetrying() {
327348 assertEquals (ErrorCode .DEADLINE_EXCEEDED , exception .getErrorCode ());
328349 }
329350 int numChannels = spanner .getOptions ().getNumChannels ();
330- assertEquals (numChannels , mockSpanner .countRequestsOfType (ExecuteSqlRequest .class ));
331351 List <ExecuteSqlRequest > requests = mockSpanner .getRequestsOfType (ExecuteSqlRequest .class );
332352 // The requests use the same multiplexed session.
333353 String session = requests .get (0 ).getSession ();
334354 for (ExecuteSqlRequest request : requests ) {
335355 assertEquals (session , request .getSession ());
336356 }
337- // The requests use all gRPC channels.
338- assertEquals (
339- numChannels ,
340- SERVER_ADDRESSES
341- .getOrDefault ("google.spanner.v1.Spanner/ExecuteStreamingSql" , ImmutableSet .of ())
342- .size ());
357+ // Each attempt, including retries, must use a distinct channel hint.
358+ int totalRequests = mockSpanner .countRequestsOfType (ExecuteSqlRequest .class );
359+ int distinctHints =
360+ CHANNEL_HINTS
361+ .getOrDefault ("google.spanner.v1.Spanner/ExecuteStreamingSql" , new HashSet <>())
362+ .size ();
363+ assertEquals (totalRequests , distinctHints );
343364 }
344365 }
345366
0 commit comments