diff --git a/.github/workflows/array-api-tests-jax.yml b/.github/workflows/array-api-tests-jax.yml
new file mode 100644
index 00000000..59b70930
--- /dev/null
+++ b/.github/workflows/array-api-tests-jax.yml
@@ -0,0 +1,13 @@
+name: Array API Tests (JAX)
+
+on: [push, pull_request]
+
+jobs:
+  array-api-tests-jax:
+    uses: ./.github/workflows/array-api-tests.yml
+    with:
+      package-name: jax
+      # See https://github.com/google/jax/issues/22137 for reason behind skipped dtypes
+      extra-env-vars: |
+        JAX_ENABLE_X64=1
+        ARRAY_API_TESTS_SKIP_DTYPES=uint8,uint16,uint32,uint64
diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml
index 6e709438..a17514fd 100644
--- a/.github/workflows/array-api-tests.yml
+++ b/.github/workflows/array-api-tests.yml
@@ -33,7 +33,7 @@ on:
         description: "Multiline string of environment variables to set for the test run."
 
 env:
-  PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline"
+  PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} -k top_k --hypothesis-disable-deadline"
 
 jobs:
   tests:
@@ -50,9 +50,10 @@ jobs:
     - name: Checkout array-api-tests
       uses: actions/checkout@v4
       with:
-        repository: data-apis/array-api-tests
+        repository: JuliaPoo/array-api-tests
         submodules: 'true'
         path: array-api-tests
+        ref: ci-wip-topk-tests
     - name: Set up Python ${{ matrix.python-version }}
       uses: actions/setup-python@v5
       with:
@@ -77,6 +78,7 @@ jobs:
         # This enables the NEP 50 type promotion behavior (without it a lot of
         # tests fail on bad scalar type promotion behavior)
         NPY_PROMOTION_STATE: weak
+        ARRAY_API_TESTS_VERSION: draft
       run: |
         export PYTHONPATH="${GITHUB_WORKSPACE}/array-api-compat"
         cd ${GITHUB_WORKSPACE}/array-api-tests
diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py
index d2aac8b2..9a4b897d 100644
--- a/array_api_compat/dask/array/_aliases.py
+++ b/array_api_compat/dask/array/_aliases.py
@@ -150,6 +150,28 @@ def asarray(
 
     return da.asarray(obj, dtype=dtype, **kwargs)
 
+
+def top_k(
+    x: Array,
+    k: int,
+    /,
+    axis: Optional[int] = None,
+    *,
+    largest: bool = True,
+) -> tuple[Array, Array]:
+
+    if not largest:
+        k = -k
+
+    # For now, perform the computation twice,
+    # since an equivalent to numpy's `take_along_axis`
+    # does not exist.
+    # See https://github.com/dask/dask/issues/3663.
+    args = da.argtopk(x, k, axis=axis).compute()
+    vals = da.topk(x, k, axis=axis).compute()
+    return vals, args
+
+
 from dask.array import (
     # Element wise aliases
     arccos as acos,
@@ -178,6 +200,7 @@ def asarray(
                             'bitwise_right_shift', 'concat', 'pow',
                             'e', 'inf', 'nan', 'pi', 'newaxis', 'float32', 'float64', 'int8',
                             'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64',
-                            'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type']
+                            'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type',
+                            'top_k']
 
 _all_ignore = ['get_xp', 'da', 'partial', 'common_aliases', 'np']
diff --git a/array_api_compat/jax/__init__.py b/array_api_compat/jax/__init__.py
new file mode 100644
index 00000000..4282af15
--- /dev/null
+++ b/array_api_compat/jax/__init__.py
@@ -0,0 +1,85 @@
+from jax.numpy import (
+    # Constants
+    e,
+    inf,
+    nan,
+    pi,
+    newaxis,
+    # Dtypes
+    bool,
+    float32,
+    float64,
+    int8,
+    int16,
+    int32,
+    int64,
+    uint8,
+    uint16,
+    uint32,
+    uint64,
+    complex64,
+    complex128,
+    iinfo,
+    finfo,
+    can_cast,
+    result_type,
+    # functions
+    zeros,
+    all,
+    any,
+    isnan,
+    isfinite,
+    reshape
+)
+from jax.numpy import (
+    asarray,
+    s_,
+    int_,
+    argpartition,
+    take_along_axis
+)
+
+
+def top_k(
+    x,
+    k,
+    /,
+    axis=None,
+    *,
+    largest=True,
+):
+    # The largest keyword can't be implemented with `jax.lax.top_k`
+    # efficiently so am using `jax.numpy` for now
+    if k <= 0:
+        raise ValueError(f'k(={k}) provided must be positive.')
+
+    positive_axis: int
+    _arr = asarray(x)
+    if axis is None:
+        arr = _arr.ravel()
+        positive_axis = 0
+    else:
+        arr = _arr
+        positive_axis = axis if axis > 0 else axis % arr.ndim
+
+    slice_start = (s_[:],) * positive_axis
+    if largest:
+        indices_array = argpartition(arr, -k, axis=axis)
+        slice = slice_start + (s_[-k:],)
+        topk_indices = indices_array[slice]
+    else:
+        indices_array = argpartition(arr, k-1, axis=axis)
+        slice = slice_start + (s_[:k],)
+        topk_indices = indices_array[slice]
+
+    topk_indices = topk_indices.astype(int_)
+    topk_values = take_along_axis(arr, topk_indices, axis=axis)
+    return (topk_values, topk_indices)
+
+
+__all__ = ['top_k', 'e', 'inf', 'nan', 'pi', 'newaxis', 'bool',
+           'float32', 'float64', 'int8', 'int16', 'int32',
+           'int64', 'uint8', 'uint16', 'uint32', 'uint64',
+           'complex64', 'complex128', 'iinfo', 'finfo',
+           'can_cast', 'result_type', 'zeros', 'all', 'isnan',
+           'isfinite', 'reshape', 'any']
diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py
index 70378716..ae28dac8 100644
--- a/array_api_compat/numpy/_aliases.py
+++ b/array_api_compat/numpy/_aliases.py
@@ -61,6 +61,35 @@
 matrix_transpose = get_xp(np)(_aliases.matrix_transpose)
 tensordot = get_xp(np)(_aliases.tensordot)
 
+
+def top_k(a, k, /, axis=-1, *, largest=True):
+    if k <= 0:
+        raise ValueError(f'k(={k}) provided must be positive.')
+
+    positive_axis: int
+    _arr = np.asanyarray(a)
+    if axis is None:
+        arr = _arr.ravel()
+        positive_axis = 0
+    else:
+        arr = _arr
+        positive_axis = axis if axis > 0 else axis % arr.ndim
+
+    slice_start = (np.s_[:],) * positive_axis
+    if largest:
+        indices_array = np.argpartition(arr, -k, axis=axis)
+        slice = slice_start + (np.s_[-k:],)
+        topk_indices = indices_array[slice]
+    else:
+        indices_array = np.argpartition(arr, k-1, axis=axis)
+        slice = slice_start + (np.s_[:k],)
+        topk_indices = indices_array[slice]
+
+    topk_values = np.take_along_axis(arr, topk_indices, axis=axis)
+
+    return (topk_values, topk_indices)
+
+
 def _supports_buffer_protocol(obj):
     try:
         memoryview(obj)
@@ -126,6 +155,6 @@ def asarray(
 __all__ = _aliases.__all__ + ['asarray', 'bool', 'acos',
                               'acosh', 'asin', 'asinh', 'atan', 'atan2',
                               'atanh', 'bitwise_left_shift', 'bitwise_invert',
-                              'bitwise_right_shift', 'concat', 'pow']
+                              'bitwise_right_shift', 'concat', 'pow', 'top_k']
 
 _all_ignore = ['np', 'get_xp']
diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py
index fb53e0ee..603dc15e 100644
--- a/array_api_compat/torch/_aliases.py
+++ b/array_api_compat/torch/_aliases.py
@@ -700,6 +700,8 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
         axis = 0
     return torch.index_select(x, axis, indices, **kwargs)
 
+top_k = torch.topk
+
 __all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert',
            'newaxis', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift',
            'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'divide',
@@ -713,6 +715,6 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
            'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
            'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
            'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype',
-           'take']
+           'take', 'top_k']
 
 _all_ignore = ['torch', 'get_xp']
diff --git a/jax-skips.txt b/jax-skips.txt
new file mode 100644
index 00000000..e69de29b
diff --git a/jax-xfails.txt b/jax-xfails.txt
new file mode 100644
index 00000000..e69de29b