001/**
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *     http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 */
018package org.apache.hadoop.hbase.client.coprocessor;
019
020import static org.apache.hadoop.hbase.client.coprocessor.AggregationHelper.getParsedGenericInstance;
021import static org.apache.hadoop.hbase.client.coprocessor.AggregationHelper.validateArgAndGetPB;
022import static org.apache.hadoop.hbase.util.FutureUtils.addListener;
023
024import java.io.IOException;
025import java.util.Map;
026import java.util.NavigableMap;
027import java.util.NavigableSet;
028import java.util.NoSuchElementException;
029import java.util.TreeMap;
030import java.util.concurrent.CompletableFuture;
031import org.apache.hadoop.hbase.Cell;
032import org.apache.hadoop.hbase.HConstants;
033import org.apache.hadoop.hbase.client.AdvancedScanResultConsumer;
034import org.apache.hadoop.hbase.client.AsyncTable;
035import org.apache.hadoop.hbase.client.AsyncTable.CoprocessorCallback;
036import org.apache.hadoop.hbase.client.RegionInfo;
037import org.apache.hadoop.hbase.client.Result;
038import org.apache.hadoop.hbase.client.Scan;
039import org.apache.hadoop.hbase.coprocessor.ColumnInterpreter;
040import org.apache.hadoop.hbase.util.Bytes;
041import org.apache.hadoop.hbase.util.ReflectionUtils;
042import org.apache.yetus.audience.InterfaceAudience;
043
044import org.apache.hbase.thirdparty.com.google.protobuf.Message;
045
046import org.apache.hadoop.hbase.shaded.protobuf.generated.AggregateProtos.AggregateRequest;
047import org.apache.hadoop.hbase.shaded.protobuf.generated.AggregateProtos.AggregateResponse;
048import org.apache.hadoop.hbase.shaded.protobuf.generated.AggregateProtos.AggregateService;
049
050/**
051 * This client class is for invoking the aggregate functions deployed on the Region Server side via
052 * the AggregateService. This class will implement the supporting functionality for
053 * summing/processing the individual results obtained from the AggregateService for each region.
054 */
055@InterfaceAudience.Public
056public final class AsyncAggregationClient {
057  private AsyncAggregationClient() {}
058
059  private static abstract class AbstractAggregationCallback<T>
060      implements CoprocessorCallback<AggregateResponse> {
061    private final CompletableFuture<T> future;
062
063    protected boolean finished = false;
064
065    private void completeExceptionally(Throwable error) {
066      if (finished) {
067        return;
068      }
069      finished = true;
070      future.completeExceptionally(error);
071    }
072
073    protected AbstractAggregationCallback(CompletableFuture<T> future) {
074      this.future = future;
075    }
076
077    @Override
078    public synchronized void onRegionError(RegionInfo region, Throwable error) {
079      completeExceptionally(error);
080    }
081
082    @Override
083    public synchronized void onError(Throwable error) {
084      completeExceptionally(error);
085    }
086
087    protected abstract void aggregate(RegionInfo region, AggregateResponse resp)
088        throws IOException;
089
090    @Override
091    public synchronized void onRegionComplete(RegionInfo region, AggregateResponse resp) {
092      try {
093        aggregate(region, resp);
094      } catch (IOException e) {
095        completeExceptionally(e);
096      }
097    }
098
099    protected abstract T getFinalResult();
100
101    @Override
102    public synchronized void onComplete() {
103      if (finished) {
104        return;
105      }
106      finished = true;
107      future.complete(getFinalResult());
108    }
109  }
110
111  private static <R, S, P extends Message, Q extends Message, T extends Message> R
112      getCellValueFromProto(ColumnInterpreter<R, S, P, Q, T> ci, AggregateResponse resp,
113          int firstPartIndex) throws IOException {
114    Q q = getParsedGenericInstance(ci.getClass(), 3, resp.getFirstPart(firstPartIndex));
115    return ci.getCellValueFromProto(q);
116  }
117
118  private static <R, S, P extends Message, Q extends Message, T extends Message> S
119      getPromotedValueFromProto(ColumnInterpreter<R, S, P, Q, T> ci, AggregateResponse resp,
120          int firstPartIndex) throws IOException {
121    T t = getParsedGenericInstance(ci.getClass(), 4, resp.getFirstPart(firstPartIndex));
122    return ci.getPromotedValueFromProto(t);
123  }
124
125  private static byte[] nullToEmpty(byte[] b) {
126    return b != null ? b : HConstants.EMPTY_BYTE_ARRAY;
127  }
128
129  public static <R, S, P extends Message, Q extends Message, T extends Message> CompletableFuture<R>
130      max(AsyncTable<?> table, ColumnInterpreter<R, S, P, Q, T> ci, Scan scan) {
131    CompletableFuture<R> future = new CompletableFuture<>();
132    AggregateRequest req;
133    try {
134      req = validateArgAndGetPB(scan, ci, false);
135    } catch (IOException e) {
136      future.completeExceptionally(e);
137      return future;
138    }
139    AbstractAggregationCallback<R> callback = new AbstractAggregationCallback<R>(future) {
140
141      private R max;
142
143      @Override
144      protected void aggregate(RegionInfo region, AggregateResponse resp) throws IOException {
145        if (resp.getFirstPartCount() > 0) {
146          R result = getCellValueFromProto(ci, resp, 0);
147          if (max == null || (result != null && ci.compare(max, result) < 0)) {
148            max = result;
149          }
150        }
151      }
152
153      @Override
154      protected R getFinalResult() {
155        return max;
156      }
157    };
158    table
159        .<AggregateService, AggregateResponse> coprocessorService(AggregateService::newStub,
160          (stub, controller, rpcCallback) -> stub.getMax(controller, req, rpcCallback), callback)
161        .fromRow(nullToEmpty(scan.getStartRow()), scan.includeStartRow())
162        .toRow(nullToEmpty(scan.getStopRow()), scan.includeStopRow()).execute();
163    return future;
164  }
165
166  public static <R, S, P extends Message, Q extends Message, T extends Message> CompletableFuture<R>
167      min(AsyncTable<?> table, ColumnInterpreter<R, S, P, Q, T> ci, Scan scan) {
168    CompletableFuture<R> future = new CompletableFuture<>();
169    AggregateRequest req;
170    try {
171      req = validateArgAndGetPB(scan, ci, false);
172    } catch (IOException e) {
173      future.completeExceptionally(e);
174      return future;
175    }
176
177    AbstractAggregationCallback<R> callback = new AbstractAggregationCallback<R>(future) {
178
179      private R min;
180
181      @Override
182      protected void aggregate(RegionInfo region, AggregateResponse resp) throws IOException {
183        if (resp.getFirstPartCount() > 0) {
184          R result = getCellValueFromProto(ci, resp, 0);
185          if (min == null || (result != null && ci.compare(min, result) > 0)) {
186            min = result;
187          }
188        }
189      }
190
191      @Override
192      protected R getFinalResult() {
193        return min;
194      }
195    };
196    table
197        .<AggregateService, AggregateResponse> coprocessorService(AggregateService::newStub,
198          (stub, controller, rpcCallback) -> stub.getMin(controller, req, rpcCallback), callback)
199        .fromRow(nullToEmpty(scan.getStartRow()), scan.includeStartRow())
200        .toRow(nullToEmpty(scan.getStopRow()), scan.includeStopRow()).execute();
201    return future;
202  }
203
204  public static <R, S, P extends Message, Q extends Message, T extends Message>
205      CompletableFuture<Long> rowCount(AsyncTable<?> table, ColumnInterpreter<R, S, P, Q, T> ci,
206          Scan scan) {
207    CompletableFuture<Long> future = new CompletableFuture<>();
208    AggregateRequest req;
209    try {
210      req = validateArgAndGetPB(scan, ci, true);
211    } catch (IOException e) {
212      future.completeExceptionally(e);
213      return future;
214    }
215    AbstractAggregationCallback<Long> callback = new AbstractAggregationCallback<Long>(future) {
216
217      private long count;
218
219      @Override
220      protected void aggregate(RegionInfo region, AggregateResponse resp) throws IOException {
221        count += resp.getFirstPart(0).asReadOnlyByteBuffer().getLong();
222      }
223
224      @Override
225      protected Long getFinalResult() {
226        return count;
227      }
228    };
229    table
230        .<AggregateService, AggregateResponse> coprocessorService(AggregateService::newStub,
231          (stub, controller, rpcCallback) -> stub.getRowNum(controller, req, rpcCallback), callback)
232        .fromRow(nullToEmpty(scan.getStartRow()), scan.includeStartRow())
233        .toRow(nullToEmpty(scan.getStopRow()), scan.includeStopRow()).execute();
234    return future;
235  }
236
237  public static <R, S, P extends Message, Q extends Message, T extends Message> CompletableFuture<S>
238      sum(AsyncTable<?> table, ColumnInterpreter<R, S, P, Q, T> ci, Scan scan) {
239    CompletableFuture<S> future = new CompletableFuture<>();
240    AggregateRequest req;
241    try {
242      req = validateArgAndGetPB(scan, ci, false);
243    } catch (IOException e) {
244      future.completeExceptionally(e);
245      return future;
246    }
247    AbstractAggregationCallback<S> callback = new AbstractAggregationCallback<S>(future) {
248      private S sum;
249
250      @Override
251      protected void aggregate(RegionInfo region, AggregateResponse resp) throws IOException {
252        if (resp.getFirstPartCount() > 0) {
253          S s = getPromotedValueFromProto(ci, resp, 0);
254          sum = ci.add(sum, s);
255        }
256      }
257
258      @Override
259      protected S getFinalResult() {
260        return sum;
261      }
262    };
263    table
264        .<AggregateService, AggregateResponse> coprocessorService(AggregateService::newStub,
265          (stub, controller, rpcCallback) -> stub.getSum(controller, req, rpcCallback), callback)
266        .fromRow(nullToEmpty(scan.getStartRow()), scan.includeStartRow())
267        .toRow(nullToEmpty(scan.getStopRow()), scan.includeStopRow()).execute();
268    return future;
269  }
270
271  public static <R, S, P extends Message, Q extends Message, T extends Message>
272      CompletableFuture<Double> avg(AsyncTable<?> table, ColumnInterpreter<R, S, P, Q, T> ci,
273          Scan scan) {
274    CompletableFuture<Double> future = new CompletableFuture<>();
275    AggregateRequest req;
276    try {
277      req = validateArgAndGetPB(scan, ci, false);
278    } catch (IOException e) {
279      future.completeExceptionally(e);
280      return future;
281    }
282    AbstractAggregationCallback<Double> callback = new AbstractAggregationCallback<Double>(future) {
283      private S sum;
284
285      long count = 0L;
286
287      @Override
288      protected void aggregate(RegionInfo region, AggregateResponse resp) throws IOException {
289        if (resp.getFirstPartCount() > 0) {
290          sum = ci.add(sum, getPromotedValueFromProto(ci, resp, 0));
291          count += resp.getSecondPart().asReadOnlyByteBuffer().getLong();
292        }
293      }
294
295      @Override
296      protected Double getFinalResult() {
297        return ci.divideForAvg(sum, count);
298      }
299    };
300    table
301        .<AggregateService, AggregateResponse> coprocessorService(AggregateService::newStub,
302          (stub, controller, rpcCallback) -> stub.getAvg(controller, req, rpcCallback), callback)
303        .fromRow(nullToEmpty(scan.getStartRow()), scan.includeStartRow())
304        .toRow(nullToEmpty(scan.getStopRow()), scan.includeStopRow()).execute();
305    return future;
306  }
307
308  public static <R, S, P extends Message, Q extends Message, T extends Message>
309      CompletableFuture<Double> std(AsyncTable<?> table, ColumnInterpreter<R, S, P, Q, T> ci,
310          Scan scan) {
311    CompletableFuture<Double> future = new CompletableFuture<>();
312    AggregateRequest req;
313    try {
314      req = validateArgAndGetPB(scan, ci, false);
315    } catch (IOException e) {
316      future.completeExceptionally(e);
317      return future;
318    }
319    AbstractAggregationCallback<Double> callback = new AbstractAggregationCallback<Double>(future) {
320
321      private S sum;
322
323      private S sumSq;
324
325      private long count;
326
327      @Override
328      protected void aggregate(RegionInfo region, AggregateResponse resp) throws IOException {
329        if (resp.getFirstPartCount() > 0) {
330          sum = ci.add(sum, getPromotedValueFromProto(ci, resp, 0));
331          sumSq = ci.add(sumSq, getPromotedValueFromProto(ci, resp, 1));
332          count += resp.getSecondPart().asReadOnlyByteBuffer().getLong();
333        }
334      }
335
336      @Override
337      protected Double getFinalResult() {
338        double avg = ci.divideForAvg(sum, count);
339        double avgSq = ci.divideForAvg(sumSq, count);
340        return Math.sqrt(avgSq - avg * avg);
341      }
342    };
343    table
344        .<AggregateService, AggregateResponse> coprocessorService(AggregateService::newStub,
345          (stub, controller, rpcCallback) -> stub.getStd(controller, req, rpcCallback), callback)
346        .fromRow(nullToEmpty(scan.getStartRow()), scan.includeStartRow())
347        .toRow(nullToEmpty(scan.getStopRow()), scan.includeStopRow()).execute();
348    return future;
349  }
350
351  // the map key is the startRow of the region
352  private static <R, S, P extends Message, Q extends Message, T extends Message>
353      CompletableFuture<NavigableMap<byte[], S>>
354      sumByRegion(AsyncTable<?> table, ColumnInterpreter<R, S, P, Q, T> ci, Scan scan) {
355    CompletableFuture<NavigableMap<byte[], S>> future =
356        new CompletableFuture<NavigableMap<byte[], S>>();
357    AggregateRequest req;
358    try {
359      req = validateArgAndGetPB(scan, ci, false);
360    } catch (IOException e) {
361      future.completeExceptionally(e);
362      return future;
363    }
364    int firstPartIndex = scan.getFamilyMap().get(scan.getFamilies()[0]).size() - 1;
365    AbstractAggregationCallback<NavigableMap<byte[], S>> callback =
366        new AbstractAggregationCallback<NavigableMap<byte[], S>>(future) {
367
368      private final NavigableMap<byte[], S> map = new TreeMap<>(Bytes.BYTES_COMPARATOR);
369
370        @Override
371        protected void aggregate(RegionInfo region, AggregateResponse resp) throws IOException {
372          if (resp.getFirstPartCount() > 0) {
373            map.put(region.getStartKey(), getPromotedValueFromProto(ci, resp, firstPartIndex));
374          }
375        }
376
377        @Override
378        protected NavigableMap<byte[], S> getFinalResult() {
379          return map;
380        }
381      };
382    table
383        .<AggregateService, AggregateResponse> coprocessorService(AggregateService::newStub,
384          (stub, controller, rpcCallback) -> stub.getMedian(controller, req, rpcCallback), callback)
385        .fromRow(nullToEmpty(scan.getStartRow()), scan.includeStartRow())
386        .toRow(nullToEmpty(scan.getStopRow()), scan.includeStopRow()).execute();
387    return future;
388  }
389
390  private static <R, S, P extends Message, Q extends Message, T extends Message> void findMedian(
391          CompletableFuture<R> future, AsyncTable<AdvancedScanResultConsumer> table,
392          ColumnInterpreter<R, S, P, Q, T> ci, Scan scan, NavigableMap<byte[], S> sumByRegion) {
393    double halfSum = ci.divideForAvg(sumByRegion.values().stream().reduce(ci::add).get(), 2L);
394    S movingSum = null;
395    byte[] startRow = null;
396    for (Map.Entry<byte[], S> entry : sumByRegion.entrySet()) {
397      startRow = entry.getKey();
398      S newMovingSum = ci.add(movingSum, entry.getValue());
399      if (ci.divideForAvg(newMovingSum, 1L) > halfSum) {
400        break;
401      }
402      movingSum = newMovingSum;
403    }
404    if (startRow != null) {
405      scan.withStartRow(startRow);
406    }
407    // we can not pass movingSum directly to an anonymous class as it is not final.
408    S baseSum = movingSum;
409    byte[] family = scan.getFamilies()[0];
410    NavigableSet<byte[]> qualifiers = scan.getFamilyMap().get(family);
411    byte[] weightQualifier = qualifiers.last();
412    byte[] valueQualifier = qualifiers.first();
413    table.scan(scan, new AdvancedScanResultConsumer() {
414      private S sum = baseSum;
415
416      private R value = null;
417
418      @Override
419      public void onNext(Result[] results, ScanController controller) {
420        try {
421          for (Result result : results) {
422            Cell weightCell = result.getColumnLatestCell(family, weightQualifier);
423            R weight = ci.getValue(family, weightQualifier, weightCell);
424            sum = ci.add(sum, ci.castToReturnType(weight));
425            if (ci.divideForAvg(sum, 1L) > halfSum) {
426              if (value != null) {
427                future.complete(value);
428              } else {
429                future.completeExceptionally(new NoSuchElementException());
430              }
431              controller.terminate();
432              return;
433            }
434            Cell valueCell = result.getColumnLatestCell(family, valueQualifier);
435            value = ci.getValue(family, valueQualifier, valueCell);
436          }
437        } catch (IOException e) {
438          future.completeExceptionally(e);
439          controller.terminate();
440        }
441      }
442
443      @Override
444      public void onError(Throwable error) {
445        future.completeExceptionally(error);
446      }
447
448      @Override
449      public void onComplete() {
450        if (!future.isDone()) {
451          // we should not reach here as the future should be completed in onNext.
452          future.completeExceptionally(new NoSuchElementException());
453        }
454      }
455    });
456  }
457
458  public static <R, S, P extends Message, Q extends Message, T extends Message>
459      CompletableFuture<R> median(AsyncTable<AdvancedScanResultConsumer> table,
460      ColumnInterpreter<R, S, P, Q, T> ci, Scan scan) {
461    CompletableFuture<R> future = new CompletableFuture<>();
462    addListener(sumByRegion(table, ci, scan), (sumByRegion, error) -> {
463      if (error != null) {
464        future.completeExceptionally(error);
465      } else if (sumByRegion.isEmpty()) {
466        future.completeExceptionally(new NoSuchElementException());
467      } else {
468        findMedian(future, table, ci, ReflectionUtils.newInstance(scan.getClass(), scan),
469          sumByRegion);
470      }
471    });
472    return future;
473  }
474}