// React Example: Virtualized Columns

import React from 'react'
import ReactDOM from 'react-dom/client'
import './index.css'
import {
  columnSizingFeature,
  columnVisibilityFeature,
  createColumnHelper,
  createSortedRowModel,
  rowSortingFeature,
  sortFns,
  useTable,
} from '@tanstack/react-table'
import { useVirtualizer } from '@tanstack/react-virtual'
import { makeColumns, makeData } from './makeData'
import type {
  Cell,
  Header,
  HeaderGroup,
  ReactTable,
  Row,
} from '@tanstack/react-table'
import type { VirtualItem, Virtualizer } from '@tanstack/react-virtual'
import type { Person } from './makeData'

/**
 * Feature set enabled for this example's table: column sizing, column
 * visibility, and row sorting (sorted row model plus the built-in sort fns).
 */
const features = {
  columnSizingFeature,
  columnVisibilityFeature,
  rowSortingFeature,
  sortedRowModel: createSortedRowModel(),
  sortFns,
}

// Column helper bound to this example's feature set and row type.
const columnHelper = createColumnHelper<typeof features, Person>()

// Initial dataset dimensions.
const DEFAULT_ROW_COUNT = 1_000
const DEFAULT_COLUMN_COUNT = 1_000
// Dataset dimensions used by the "Stress Test" buttons.
const STRESS_ROW_COUNT = 10_000
const STRESS_COLUMN_COUNT = 10_000

/**
 * Builds `columnCount` accessor column definitions from the generated column
 * descriptors, carrying over each descriptor's header text and initial size.
 */
const makeTableColumns = (columnCount: number) =>
  columnHelper.columns(
    makeColumns(columnCount).map(({ accessorKey, header, size }) =>
      columnHelper.accessor(accessorKey, { header, size }),
    ),
  )

/**
 * Root component. Owns the generated columns and rows, offers buttons to
 * regenerate or enlarge the dataset, and renders the virtualized table.
 */
function App() {
  const [columns, setColumns] = React.useState(() =>
    makeTableColumns(DEFAULT_COLUMN_COUNT),
  )

  // Initial rows are generated lazily against the initial columns.
  const [data, setData] = React.useState(() =>
    makeData(DEFAULT_ROW_COUNT, columns),
  )

  // Regenerate both columns and rows at their default sizes.
  const refreshData = React.useCallback(() => {
    const freshColumns = makeTableColumns(DEFAULT_COLUMN_COUNT)
    setColumns(freshColumns)
    setData(makeData(DEFAULT_ROW_COUNT, freshColumns))
  }, [])

  // Keep the current columns and generate STRESS_ROW_COUNT rows for them.
  const stressTestRows = React.useCallback(() => {
    setData(makeData(STRESS_ROW_COUNT, columns))
  }, [columns])

  // Keep the current row count, switch to STRESS_COLUMN_COUNT columns, and
  // regenerate the rows against the new columns.
  const rowCount = data.length
  const stressTestColumns = React.useCallback(() => {
    const widerColumns = makeTableColumns(STRESS_COLUMN_COUNT)
    setColumns(widerColumns)
    setData(makeData(rowCount, widerColumns))
  }, [rowCount])

  const table = useTable(
    {
      features,
      columns,
      data,
      debugTable: true,
    },
    (state) => state, // select the whole table state
  )

  const isDevelopment = process.env.NODE_ENV === 'development'

  // Every important style in this example is inline for readability; real
  // applications should prefer stylesheets.
  return (
    <div className="app">
      {isDevelopment ? (
        <p>
          <strong>Notice:</strong> You are currently running React in
          development mode. Virtualized rendering performance will be slightly
          degraded until this application is built for production.
        </p>
      ) : null}
      <div>({columns.length.toLocaleString()} columns)</div>
      <div>({data.length.toLocaleString()} rows)</div>
      <div>
        <button onClick={refreshData}>Regenerate Data</button>
        <button onClick={stressTestRows}>Stress Test (10k rows)</button>
        <button onClick={stressTestColumns}>Stress Test (10k columns)</button>
      </div>
      <TableContainer table={table} />
    </div>
  )
}

interface TableContainerProps {
  /** Table instance whose visible columns and rows are rendered. */
  table: ReactTable<typeof features, Person>
}

/**
 * Fixed-height scroll container that owns the horizontal (column)
 * virtualizer and renders the table head and body inside it.
 *
 * Columns are virtualized by rendering only the current window of columns and
 * padding each row on the left/right with empty spacer cells whose widths
 * equal the skipped columns, instead of absolutely positioning cells.
 */
function TableContainer({ table }: TableContainerProps) {
  const visibleColumns = table.getVisibleLeafColumns()

  // Both the column and the row virtualizer read scroll state from this element.
  const tableContainerRef = React.useRef<HTMLDivElement>(null)

  // Columns use the spacer-cell strategy (rows use translateY) so that rows
  // can keep dynamic heights.
  const columnVirtualizer = useVirtualizer<
    HTMLDivElement,
    HTMLTableCellElement
  >({
    count: visibleColumns.length,
    // Estimated column width; accurate estimates keep scrollbar dragging precise.
    estimateSize: (index) => visibleColumns[index].getSize(),
    getScrollElement: () => tableContainerRef.current,
    horizontal: true,
    // Columns rendered off-screen on each side; tune for performance.
    overscan: 3,
  })

  const virtualColumns = columnVirtualizer.getVirtualItems()
  const hasVirtualColumns = virtualColumns.length > 0

  // Total width of the columns skipped before / after the rendered window.
  // Both stay undefined when nothing is virtualized.
  const virtualPaddingLeft: number | undefined = hasVirtualColumns
    ? (virtualColumns[0]?.start ?? 0)
    : undefined
  const virtualPaddingRight: number | undefined = hasVirtualColumns
    ? columnVirtualizer.getTotalSize() -
      (virtualColumns[virtualColumns.length - 1]?.end ?? 0)
    : undefined

  return (
    <div
      className="container"
      ref={tableContainerRef}
      style={{
        overflow: 'auto', // the scrollable table container
        position: 'relative', // required for the sticky header
        height: '800px', // must be a fixed height
      }}
    >
      {/* Semantic table tags are kept, but CSS grid and flexbox are required for dynamic row heights */}
      <table style={{ display: 'grid' }}>
        <TableHead
          columnVirtualizer={columnVirtualizer}
          table={table}
          virtualPaddingLeft={virtualPaddingLeft}
          virtualPaddingRight={virtualPaddingRight}
        />
        <TableBody
          columnVirtualizer={columnVirtualizer}
          table={table}
          tableContainerRef={tableContainerRef}
          virtualPaddingLeft={virtualPaddingLeft}
          virtualPaddingRight={virtualPaddingRight}
        />
      </table>
    </div>
  )
}

interface TableHeadProps {
  columnVirtualizer: Virtualizer<HTMLDivElement, HTMLTableCellElement>
  table: ReactTable<typeof features, Person>
  /** Width of the left spacer cell, or undefined/0 for none. */
  virtualPaddingLeft: number | undefined
  /** Width of the right spacer cell, or undefined/0 for none. */
  virtualPaddingRight: number | undefined
}

/**
 * Sticky table header that renders one virtualized row per header group.
 */
function TableHead(props: TableHeadProps) {
  const { columnVirtualizer, table, virtualPaddingLeft, virtualPaddingRight } =
    props
  const headerGroups = table.getHeaderGroups()

  return (
    <thead
      style={{
        display: 'grid',
        position: 'sticky', // stays pinned while the container scrolls
        top: 0,
        zIndex: 1, // draw above body rows
      }}
    >
      {headerGroups.map((headerGroup) => (
        <TableHeadRow
          key={headerGroup.id}
          headerGroup={headerGroup}
          columnVirtualizer={columnVirtualizer}
          table={table}
          virtualPaddingLeft={virtualPaddingLeft}
          virtualPaddingRight={virtualPaddingRight}
        />
      ))}
    </thead>
  )
}

interface TableHeadRowProps {
  columnVirtualizer: Virtualizer<HTMLDivElement, HTMLTableCellElement>
  headerGroup: HeaderGroup<typeof features, Person>
  /** Width of the left spacer cell, or undefined/0 for none. */
  virtualPaddingLeft: number | undefined
  /** Width of the right spacer cell, or undefined/0 for none. */
  virtualPaddingRight: number | undefined
  table: ReactTable<typeof features, Person>
}

/**
 * One header row: an optional left spacer cell, the header cells for the
 * currently virtualized columns, and an optional right spacer cell.
 *
 * Headers are looked up by the column virtualizer's index. This assumes the
 * header group has one header per visible leaf column, which holds for this
 * example's flat columns. NOTE(review): grouped column headers would
 * presumably need a different mapping.
 */
function TableHeadRow({
  columnVirtualizer,
  headerGroup,
  virtualPaddingLeft,
  virtualPaddingRight,
  table,
}: TableHeadRowProps) {
  const virtualColumns = columnVirtualizer.getVirtualItems()
  // No `key` on the <tr>: the parent already keys each <TableHeadRow>.
  return (
    <tr style={{ display: 'flex', width: '100%' }}>
      {virtualPaddingLeft ? (
        // fake empty column to the left for virtualization scroll padding
        <th style={{ display: 'flex', width: virtualPaddingLeft }} />
      ) : null}
      {virtualColumns.map((virtualColumn) => {
        const header = headerGroup.headers[virtualColumn.index]
        // Skip instead of throwing on `header.id` if the header group is
        // shorter than the virtualized column count.
        if (!header) return null
        return <TableHeadCell key={header.id} header={header} table={table} />
      })}
      {virtualPaddingRight ? (
        // fake empty column to the right for virtualization scroll padding
        <th style={{ display: 'flex', width: virtualPaddingRight }} />
      ) : null}
    </tr>
  )
}

interface TableHeadCellProps {
  header: Header<typeof features, Person, unknown>
  table: ReactTable<typeof features, Person>
}

/**
 * Header cell sized to its column's current width. Clicking the label
 * invokes the column's sorting toggle handler, and the current sort
 * direction is shown as an arrow after the header content.
 */
function TableHeadCell({ header, table }: TableHeadCellProps) {
  const { column } = header

  // Only 'asc' and 'desc' produce an arrow; any other value (e.g. unsorted)
  // renders nothing. Compared directly instead of asserting `as string`.
  const sortDirection = column.getIsSorted()
  const sortIndicator =
    sortDirection === 'asc' ? ' 🔼' : sortDirection === 'desc' ? ' 🔽' : null

  // No `key` on the <th>: keys only matter on the element returned from the
  // parent's `.map`, where TableHeadRow already sets one.
  return (
    <th
      style={{
        display: 'flex',
        width: header.getSize(),
      }}
    >
      <div
        className={column.getCanSort() ? 'sortable-header' : ''}
        onClick={column.getToggleSortingHandler()}
      >
        <table.FlexRender header={header} />
        {sortIndicator}
      </div>
    </th>
  )
}

interface TableBodyProps {
  columnVirtualizer: Virtualizer<HTMLDivElement, HTMLTableCellElement>
  table: ReactTable<typeof features, Person>
  /** Scroll container the row virtualizer observes. */
  tableContainerRef: React.RefObject<HTMLDivElement | null>
  /** Width of the left spacer cell, or undefined/0 for none. */
  virtualPaddingLeft: number | undefined
  /** Width of the right spacer cell, or undefined/0 for none. */
  virtualPaddingRight: number | undefined
}

/**
 * Table body with vertical (row) virtualization.
 *
 * The <tbody> is as tall as the virtualizer's total size, so the scrollbar
 * reflects every row. Only the virtual rows are rendered, each one absolutely
 * positioned via translateY. Rows are measured after render so they can have
 * dynamic heights; a fixed-height strategy would not need `measureElement`.
 */
function TableBody({
  columnVirtualizer,
  table,
  tableContainerRef,
  virtualPaddingLeft,
  virtualPaddingRight,
}: TableBodyProps) {
  const rows = table.getRowModel().rows

  // Measure real row heights only in a browser, and not in Firefox, which
  // measures table border height incorrectly.
  const canMeasureRows =
    typeof window !== 'undefined' &&
    !navigator.userAgent.includes('Firefox')

  const rowVirtualizer = useVirtualizer<HTMLDivElement, HTMLTableRowElement>({
    count: rows.length,
    // Estimated row height; keeps scrollbar dragging accurate before measurement.
    estimateSize: () => 33,
    getScrollElement: () => tableContainerRef.current,
    measureElement: canMeasureRows
      ? (element) => element.getBoundingClientRect().height
      : undefined,
    overscan: 5,
  })

  const virtualRows = rowVirtualizer.getVirtualItems()

  return (
    <tbody
      style={{
        display: 'grid',
        height: `${rowVirtualizer.getTotalSize()}px`, // full scroll height
        position: 'relative', // anchor for the absolutely positioned rows
      }}
    >
      {virtualRows.map((virtualRow) => {
        const row = rows[virtualRow.index]
        return (
          <TableBodyRow
            key={row.id}
            row={row}
            virtualRow={virtualRow}
            rowVirtualizer={rowVirtualizer}
            columnVirtualizer={columnVirtualizer}
            virtualPaddingLeft={virtualPaddingLeft}
            virtualPaddingRight={virtualPaddingRight}
            table={table}
          />
        )
      })}
    </tbody>
  )
}

interface TableBodyRowProps {
  columnVirtualizer: Virtualizer<HTMLDivElement, HTMLTableCellElement>
  row: Row<typeof features, Person>
  rowVirtualizer: Virtualizer<HTMLDivElement, HTMLTableRowElement>
  /** Width of the left spacer cell, or undefined/0 for none. */
  virtualPaddingLeft: number | undefined
  /** Width of the right spacer cell, or undefined/0 for none. */
  virtualPaddingRight: number | undefined
  virtualRow: VirtualItem
  table: ReactTable<typeof features, Person>
}

/**
 * One absolutely positioned body row. It is offset by the virtual row's start
 * position and registers itself with the row virtualizer for height
 * measurement. It renders only the cells of the currently virtualized
 * columns, between optional left/right spacer cells.
 */
function TableBodyRow({
  columnVirtualizer,
  row,
  rowVirtualizer,
  virtualPaddingLeft,
  virtualPaddingRight,
  virtualRow,
  table,
}: TableBodyRowProps) {
  const visibleCells = row.getVisibleCells()
  const virtualColumns = columnVirtualizer.getVirtualItems()
  // No `key` on the <tr>: the parent already keys each <TableBodyRow>.
  return (
    <tr
      data-index={virtualRow.index} // needed for dynamic row height measurement
      ref={(node) => rowVirtualizer.measureElement(node)} // measure dynamic row height
      style={{
        display: 'flex',
        position: 'absolute',
        transform: `translateY(${virtualRow.start}px)`, // this should always be a `style` as it changes on scroll
        width: '100%',
      }}
    >
      {virtualPaddingLeft ? (
        // fake empty column to the left for virtualization scroll padding
        <td style={{ display: 'flex', width: virtualPaddingLeft }} />
      ) : null}
      {virtualColumns.map((virtualColumn) => {
        const cell = visibleCells[virtualColumn.index]
        // Skip instead of throwing on `cell.id` if the row has fewer visible
        // cells than the virtualized column count.
        if (!cell) return null
        return <TableBodyCell key={cell.id} cell={cell} table={table} />
      })}
      {virtualPaddingRight ? (
        // fake empty column to the right for virtualization scroll padding
        <td style={{ display: 'flex', width: virtualPaddingRight }} />
      ) : null}
    </tr>
  )
}

interface TableBodyCellProps {
  cell: Cell<typeof features, Person, unknown>
  table: ReactTable<typeof features, Person>
}

/**
 * Body cell sized to its column's current width, rendering the cell content
 * through the table's FlexRender component.
 */
function TableBodyCell({ cell, table }: TableBodyCellProps) {
  // No `key` on the <td>: keys only matter on the element returned from the
  // parent's `.map`, where TableBodyRow already sets one.
  return (
    <td
      style={{
        display: 'flex',
        width: cell.column.getSize(),
      }}
    >
      <table.FlexRender cell={cell} />
    </td>
  )
}

// Mount the example into the page's #root element, failing loudly if it is absent.
const rootElement = document.getElementById('root')

if (rootElement === null) {
  throw new Error('Failed to find the root element')
}

const root = ReactDOM.createRoot(rootElement)
root.render(
  <React.StrictMode>
    <App />
  </React.StrictMode>,
)