import { IServerSideDatasource, IServerSideGetRowsParams } from "@ag-grid-community/core";
import { PageInfo } from "@cds-ui/data-access";
import { Maybe } from "graphql/jsutils/Maybe";
import {
  AsyncSubject,
  catchError,
  EMPTY,
  last,
  map,
  mergeMap,
  mergeScan,
  Observable,
  range,
  Subject,
  takeUntil,
  tap,
  throwError,
} from "rxjs";

/**
 * A server-side datasource for ag-Grid that pages through its backend and
 * publishes where it currently is.
 */
export interface IPageableServerSideDataSource {
  /** Builds the object that is handed to ag-Grid as its server-side datasource. */
  createHook: () => IServerSideDatasource;
  /** Stream of paging-state updates. */
  observe$: Observable<PageState>;
}

/**
 * One page of a Relay-style connection: the nodes of the page together with
 * the cursor information needed to fetch neighbouring pages.
 *
 * @typeParam T - node type.
 */
export interface RelayConnection<T> {
  /** Cursors and paging flags for this page. */
  pageInfo: PageInfo;
  /** The page's nodes; may be absent or null. */
  nodes?: Maybe<T[]>;
}

/**
 * Relay pagination arguments for fetching one page of a connection.
 */
export interface RelayInput {
  /** Relay `first`: number of edges to take from the start of the cursor window. */
  first: number;
  /** Relay `after`: page forward from this cursor; null/undefined for no lower bound. */
  after: string | null | undefined;
  /** Relay `last`: number of edges to take from the end of the cursor window. */
  last: number;
  /** Relay `before`: page backward from this cursor; null/undefined for no upper bound. */
  before: string | null | undefined;
  /** Sort order: one single-key `{ field: direction }` entry per sorted column. */
  sort: Array<Record<string, "ASC" | "DESC">>;
}

/**
 * Paging position carried from one grid request to the next.
 */
export interface PageState {
  /** Absolute index of the first row of the most recently loaded page. */
  startRow: number;
  /** Number of rows known to exist so far. */
  totalRows: number;
  /** Cursors of the most recently loaded page; null initially or when none were returned. */
  pageInfo: PageInfo | null;
}

/**
 * Adapts a Relay-style cursor connection to ag-Grid's server-side row model.
 *
 * ag-Grid asks for rows by absolute offset (`startRow`/`endRow`), while a Relay
 * connection can only be paged relative to a cursor. This class remembers the
 * cursors of the most recently loaded page (see `PageState`) and reaches the
 * requested block by walking one page at a time:
 * - `startRow === 0` fetches the first page directly;
 * - a later block is reached by repeated `after: endCursor` requests;
 * - an earlier block is reached by repeated `before: startCursor` requests;
 * - when no remembered cursor bounds the requested block (nothing loaded yet,
 *   or the current block is requested again) the walk restarts from the first page.
 *
 * Grid requests are processed strictly one after another (concurrency 1), so
 * `dataSource` must complete after emitting its page; otherwise later
 * requests are never served.
 *
 * @typeParam S - node type of the connection.
 * @typeParam T - connection type returned by `dataSource`.
 */
export abstract class RelayCursorDataSource<S, T extends RelayConnection<S>> implements IPageableServerSideDataSource {

  /**
   * Fetches one page of the connection for the given Relay arguments.
   * Must emit the page and then complete.
   */
  protected abstract dataSource(x: RelayInput): Observable<T>;

  private readonly signalSubject = new Subject<IServerSideGetRowsParams<any>>();
  private readonly signal$ = this.signalSubject.asObservable();

  private readonly destroySubject = new AsyncSubject<void>();
  private readonly destroy$ = this.destroySubject.asObservable();

  /**
   * Paging state after each answered grid request.
   *
   * This is a cold pipeline over the grid's requests: nothing is fetched until
   * it is subscribed, and every subscription runs its own fetch-and-answer
   * chain (each would call `params.success`). Subscribe exactly once.
   * The request stream stops once a hook from `createHook` is destroyed.
   */
  public readonly observe$ = this.signal$.pipe(
    takeUntil(this.destroy$),
    mergeScan(
      (acc, value) => this.onPageChanged(acc, value),
      { startRow: 0, totalRows: 0, pageInfo: null } as PageState,
      1
    ),
  );

  /**
   * Creates the ag-Grid datasource that forwards `getRows` calls into this instance.
   *
   * NOTE(review): every hook shares this instance's single teardown subject, so
   * destroying any one hook stops `observe$` for all of them, permanently.
   */
  public createHook(): IServerSideDatasource {
    return new AgGridDataSource(this.signalSubject, this.destroySubject);
  }

  /** Maps a fetched page to grid rows; by default the connection's nodes (or none). */
  protected populateRowData(connection: T, params: IServerSideGetRowsParams<any>): any[] {
    return connection.nodes ?? [];
  }

  /**
   * Answers one grid request: fetches the requested block, hands it to the grid
   * and returns the resulting state.
   *
   * Failures are reported via `params.fail()` and then swallowed, leaving the
   * previous state in place, so one failed request does not terminate
   * `observe$` and strand every later grid request.
   */
  private onPageChanged(state: PageState, params: IServerSideGetRowsParams<any>): Observable<PageState> {
    const startRow = params.request.startRow ?? 0;
    const pageSize = (params.request.endRow ?? 0) - startRow;
    const relayInput: RelayInput = {
      first: pageSize,
      after: null,
      last: pageSize,
      before: null,
      sort: params.request.sortModel.map(x => ({ [x.colId]: x.sort === "desc" ? "DESC" : "ASC" })),
    };

    let source: Observable<T>;
    if (startRow === 0)
      source = this.goToFirstPage(relayInput);
    else if (pageSize <= 0)
      // Walking needs a positive step: otherwise the step count is infinite or empty.
      source = throwError(() => new Error(`Invalid block size ${pageSize} requested at row ${startRow}`));
    else if (state.pageInfo == null || startRow === state.startRow)
      source = this.replayFromFirstPage(startRow, pageSize, relayInput);
    else if (startRow < state.startRow)
      source = this.rollBackward(state, startRow, pageSize, relayInput);
    else
      source = this.rollForward(state, startRow, pageSize, relayInput);

    return source.pipe(
      map(x => this.onDraw(state, params, x)),
      tap({ error: e => this.onError(e, params) }),
      // The error has been reported to the grid; completing empty keeps the
      // previous state and keeps the outer mergeScan alive.
      catchError(() => EMPTY),
    );
  }

  /** Fetches the first page (no cursors). */
  private goToFirstPage(relayInput: RelayInput): Observable<T> {
    return this.dataSource(relayInput);
  }

  /**
   * Reaches the block at `startRow` by loading the first page and walking
   * forward from it. Used when no remembered cursor bounds the requested block.
   * Costs `startRow / pageSize + 1` requests.
   */
  private replayFromFirstPage(startRow: number, pageSize: number, relayInput: RelayInput): Observable<T> {
    return this.goToFirstPage(relayInput).pipe(
      last(),
      mergeMap(first => this.rollForward(
        { startRow: 0, totalRows: 0, pageInfo: first.pageInfo ?? null },
        startRow,
        pageSize,
        relayInput,
      )),
    );
  }

  /**
   * Walks backward from `state`'s page to the block starting at `startRow`,
   * issuing one `before: startCursor` request per page, and emits only the
   * target page.
   *
   * NOTE(review): `relayInput` carries both `first` and `last` (= pageSize).
   * Under the Relay connection spec `first` is applied before `last`, so
   * together with `before` a spec-compliant server would return the first
   * pageSize edges of the whole connection rather than those just before the
   * cursor. Confirm how the backend treats this combination.
   */
  private rollBackward(state: PageState, startRow: number, pageSize: number, relayInput: RelayInput): Observable<T> {
    const steps = Math.ceil((state.startRow - startRow) / pageSize);
    return range(0, steps).pipe(
      mergeScan(
        acc => this.dataSource({ ...relayInput, before: acc.pageInfo.startCursor }),
        // Only the seed's pageInfo is ever read.
        { pageInfo: state.pageInfo } as T,
        1
      ),
      // mergeScan emits every intermediate page; only the last one is the requested block.
      last(),
    );
  }

  /**
   * Walks forward from `state`'s page to the block starting at `startRow`,
   * issuing one `after: endCursor` request per page, and emits only the
   * target page.
   */
  private rollForward(state: PageState, startRow: number, pageSize: number, relayInput: RelayInput): Observable<T> {
    const steps = Math.ceil((startRow - state.startRow) / pageSize);
    return range(0, steps).pipe(
      mergeScan(
        acc => this.dataSource({ ...relayInput, after: acc.pageInfo.endCursor }),
        // Only the seed's pageInfo is ever read.
        { pageInfo: state.pageInfo } as T,
        1
      ),
      // mergeScan emits every intermediate page; only the last one is the requested block.
      last(),
    );
  }

  /**
   * Hands a fetched page to the grid and computes the next `PageState`.
   *
   * A page shorter than the requested block marks the end of the data: the
   * grid is then given the exact row count; until then `-1` is reported.
   *
   * @param oldState - state before this request.
   * @param params - the grid request being answered; `params.success` is called.
   * @param result - the page fetched for `params.request`.
   * @returns this page's start row, the number of rows known to exist, and its cursors.
   */
  public onDraw(oldState: PageState, params: IServerSideGetRowsParams<any>, result: T): PageState {
    const startRow = params.request.startRow ?? 0;
    const endRow = params.request.endRow ?? 0;
    const loaded = result.nodes?.length ?? 0;
    const isEof = result.nodes != null && loaded < endRow - startRow;
    // Rows known to exist reach at least the end of this page (pages may have
    // been skipped, so adding to the old count could undercount). At EOF the
    // end of this page is the exact total.
    const reached = startRow + loaded;
    const totalRows = isEof ? reached : Math.max(oldState.totalRows, reached);
    params.success({
      rowData: this.populateRowData(result, params),
      rowCount: isEof ? totalRows : -1,
    });
    return { startRow, totalRows, pageInfo: result.pageInfo ?? null };
  }

  /** Logs a failed load and tells the grid the request failed. */
  private onError(error: unknown, params: IServerSideGetRowsParams<any>): void {
    console.error(error);
    params.fail();
  }
}

/**
 * ag-Grid server-side datasource that only relays grid callbacks into the
 * subjects supplied by its owner.
 */
class AgGridDataSource<T> implements IServerSideDatasource {
  /**
   * @param signal - receives every `getRows` request, unchanged.
   * @param teardown - notified and then completed when `destroy` is called.
   */
  constructor(
    private signal: Subject<IServerSideGetRowsParams<T>>,
    private teardown: Subject<void>,
  ) {}

  /** Forwards a row request from the grid to the owner. */
  getRows(params: IServerSideGetRowsParams<T>): void {
    const { signal } = this;
    signal.next(params);
  }

  /** Fires the teardown signal once, then completes it. */
  destroy(): void {
    const { teardown } = this;
    teardown.next();
    teardown.complete();
  }
}
