Compare commits
368 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8ab4d26474 | |||
| 8acba0ccff | |||
| 6f5544e0e4 | |||
| 4e44716b0c | |||
| fda71dadcb | |||
| 618c6dcaa4 | |||
| ae9d7f6b4c | |||
| 18c4368571 | |||
| 5325bec26c | |||
| 4895e011df | |||
| 6b1b3a2037 | |||
| 9b5d627a55 | |||
| 29acd25b4e | |||
| d2ee18c14f | |||
| 2ba4a62a0d | |||
| dc3519e973 | |||
| ee2c29d520 | |||
| efdd5d5a0c | |||
| de5ba47557 | |||
| e456799f1a | |||
| 5b7d204b9d | |||
| 1515d755e1 | |||
| 7ec88bf841 | |||
| dd8ebf10db | |||
| a029817d3d | |||
| ccc008eb5e | |||
| d898737d6d | |||
| 19d7281daf | |||
| 88f7505fdf | |||
| bf0aca35fa | |||
| b1409831a3 | |||
| 94db085b51 | |||
| 4e57b9fbdc | |||
| a55186cd02 | |||
| 9c0cc65973 | |||
| 459ee4e66a | |||
| 574e5d37c7 | |||
| 0d0ea981da | |||
| 0fa8f528c2 | |||
| 47805643f7 | |||
| 2a1bfb3e44 | |||
| abf14d976a | |||
| 0f3221f9d0 | |||
| c13e68248b | |||
| 9dcbe753f4 | |||
| cc1602ad78 | |||
| c619138ece | |||
| 62357084ba | |||
| aefb477e21 | |||
| 443173c071 | |||
| 3cb2d52a08 | |||
| a70278e0e1 | |||
| b1a095e486 | |||
| a64731eea5 | |||
| 934777d9ca | |||
| 5411abb9c1 | |||
| b402c6aba8 | |||
| 8047a3ea61 | |||
| cf999f3e28 | |||
| 704840c04e | |||
| 5ca17c3f63 | |||
| 3120c94c22 | |||
| 2687e3db49 | |||
| d22996ea20 | |||
| 6bc03907bd | |||
| b1faaef482 | |||
| b50bb99fe7 | |||
| daa34cf7b8 | |||
| 85c679597c | |||
| cb54986d3f | |||
| 5e594adfba | |||
| eefc26c108 | |||
| dd5173b45c | |||
| c01910fb75 | |||
| 0ad8e775a5 | |||
| 3ad27b547f | |||
| 50966c4cf7 | |||
| 34f799257c | |||
| 257df81667 | |||
| 2b6586d542 | |||
| adcf8838d2 | |||
| e8e095e2f8 | |||
| 3cc77d96eb | |||
| 3049a56355 | |||
| 915bb41ea2 | |||
| 05670d133e | |||
| 32bb1e7ce9 | |||
| a89d6909b2 | |||
| f5df54831a | |||
| fe7c9f8ec1 | |||
| 9cf3e102ba | |||
| 3b15f13ae4 | |||
| 9b1ff43e9f | |||
| ea42212a2a | |||
| 0ebcb7de55 | |||
| d275331c13 | |||
| 29a2fe46e8 | |||
| 93b8121c9b | |||
| 1386465631 | |||
| bbe2baf3f6 | |||
| 3ad5e11d22 | |||
| 9a670b90df | |||
| 2a66395fb7 | |||
| 4f3958c831 | |||
| b65c555dfc | |||
| 8d14076a23 | |||
| 3759bb2a9a | |||
| 504cd462a7 | |||
| 8940ea179e | |||
| 587017665a | |||
| 06d7e368ab | |||
| f2952000d9 | |||
| a6bb3b29d0 | |||
| db7030716d | |||
| 45c05a0896 | |||
| ffbaa93722 | |||
| 18b282cabb | |||
| 78283cf236 | |||
| d165d76338 | |||
| ce953441fb | |||
| cd4ab97efa | |||
| 6325a9ea91 | |||
| a1056bfa2a | |||
| bf63cc929a | |||
| 1d88b9c65c | |||
| 738a38d71f | |||
| 9bc0d06aa0 | |||
| aa3812ff4e | |||
| 15d7b78527 | |||
| cb62e16b41 | |||
| 46f194e7f1 | |||
| 0a95ba62b1 | |||
| 4f6355506c | |||
| df2649ed2a | |||
| d11d83cc98 | |||
| bbebc1a86a | |||
| 74cd31bdb1 | |||
| 88d49dbcab | |||
| c7aea2fc42 | |||
| 087a7defde | |||
| 951343aa06 | |||
| 2c74d974ca | |||
| 132a788c54 | |||
| f2051218ee | |||
| fc1536daab | |||
| bf0d530e78 | |||
| 093e3bb3d7 | |||
| f077e14b38 | |||
| 6c0552a5d4 | |||
| 9104a067d6 | |||
| 00d0620679 | |||
| 78a39a809d | |||
| 092a22f242 | |||
| 4919975f13 | |||
| 3ab8d57630 | |||
| d931b2c10d | |||
| 139bc6f58b | |||
| d8008de77a | |||
| 69c73b2d28 | |||
| f9b7588963 | |||
| c9bedc5e58 | |||
| 8c4de49359 | |||
| 4b540b7c42 | |||
| e49ef68ebc | |||
| 1755e97748 | |||
| d9a61dd4c8 | |||
| 776f287685 | |||
| 70d7dd9b2f | |||
| e6f568fcac | |||
| 028a4edbd4 | |||
| 6d2b7ea3ba | |||
| 9339d597b9 | |||
| 574635f43d | |||
| 0c75ea6286 | |||
| df7c7393ad | |||
| 2a7fe13397 | |||
| af502a6a66 | |||
| 6c83012082 | |||
| 518f6960d0 | |||
| d38cc75f31 | |||
| 31aaa1ed59 | |||
| 59cc4a2586 | |||
| 9775228b00 | |||
| 65ff604969 | |||
| fedb180735 | |||
| 21e6bea792 | |||
| 27c8345ef2 | |||
| 5a449bf86f | |||
| a7e08db16d | |||
| b6426083b9 | |||
| 4f453544d4 | |||
| 47a53ce6c5 | |||
| 20bcb04e8a | |||
| c86fc6e976 | |||
| 2f0d30d7b5 | |||
| 48c0b59447 | |||
| 39cf329404 | |||
| ee4b934601 | |||
| b8ec5c786d | |||
| c37ecdb9ff | |||
| 413bfb8d58 | |||
| 112eaf80d1 | |||
| 4054a9c7cb | |||
| 6571e52f17 | |||
| 28428d1294 | |||
| 3c215a83b6 | |||
| 266a3b24e7 | |||
| f9075cab0e | |||
| b64f624d17 | |||
| ea115c981d | |||
| 1c85799be5 | |||
| 15b9a59786 | |||
| 95aea104c7 | |||
| 4c8be34d81 | |||
| f160830226 | |||
| 38e2a28ada | |||
| 189c562826 | |||
| ee00a1d886 | |||
| 99013b3aed | |||
| 8cd5aeaf25 | |||
| 1214022c5a | |||
| 8738755ffc | |||
| 1e1967e0db | |||
| 7898581e50 | |||
| 6b365f46f5 | |||
| 2e30d0512e | |||
| 4183c5e1d0 | |||
| 6deef06ad2 | |||
| 80b45f1aa1 | |||
| a7ee8b31e0 | |||
| 9e099b543f | |||
| 5fd90471fc | |||
| 57c84d6446 | |||
| b77d6bdd91 | |||
| 764600003b | |||
| 7ad4276224 | |||
| 656dcc0050 | |||
| 5de6f86959 | |||
| 1bf2dc0cc3 | |||
| 5698b9d706 | |||
| 3db9ea9dd2 | |||
| 93475453d8 | |||
| ceef283bfd | |||
| d30945c5c5 | |||
| 0899548208 | |||
| eb71053e56 | |||
| 5e2efb68f1 | |||
| 50321ba2aa | |||
| bc47d7ce69 | |||
| 3618b098cb | |||
| 2ca7214259 | |||
| 8d7954b015 | |||
| 67230babc0 | |||
| 3993f66997 | |||
| db0ba8588e | |||
| 714c366d16 | |||
| 1186e643ec | |||
| 7fe7ffea6d | |||
| 72d6731924 | |||
| 153e96f574 | |||
| 794b88fab4 | |||
| 29d804e692 | |||
| adad347902 | |||
| 6f82ad09c8 | |||
| 353fb8724a | |||
| 3e4b67893e | |||
| 9196b3978d | |||
| 732bc5910c | |||
| 64e4ea73c0 | |||
| bf8d823ae3 | |||
| f314f30ebc | |||
| 36a599ea9a | |||
| 68ee82437e | |||
| d499416024 | |||
| b3d07ffd87 | |||
| 63fd4945a2 | |||
| b340634aaa | |||
| 1bca269b90 | |||
| 77acf0c340 | |||
| fc841898cd | |||
| 6e9c05f859 | |||
| 21664c5c58 | |||
| 9e12850f38 | |||
| 86fdafda23 | |||
| b2bc74e3af | |||
| 87ab6ae8a0 | |||
| b8bd3208ca | |||
| 9e9a9e0cd2 | |||
| 40c0fc285c | |||
| b78ab9e028 | |||
| 938bd7341b | |||
| 45f39ba488 | |||
| e847e7386a | |||
| ec453f01e4 | |||
| 22e49c4316 | |||
| 62d97b18f4 | |||
| a01ab27751 | |||
| b20ecfdf37 | |||
| b6712ffbee | |||
| 4f0417c6ad | |||
| 0f8c2f592e | |||
| 478d49c19c | |||
| 0552c36e29 | |||
| 9b5ee8f267 | |||
| 9ab437d6e2 | |||
| 99a7a8dd22 | |||
| f16dd5acb4 | |||
| d57c181aad | |||
| 3ded910cca | |||
| 214e59452f | |||
| 83c35bb916 | |||
| 21e8fb243b | |||
| 57c7fcf27f | |||
| 1ee1db9664 | |||
| a4980446c5 | |||
| 708bdbc134 | |||
| 850a83097c | |||
| a2098254cd | |||
| 846dd999b7 | |||
| 7e54413d3b | |||
| e9efb7e253 | |||
| 34a2d40f27 | |||
| 184e7dbce0 | |||
| 0e59cb21ce | |||
| 5c0d63d31f | |||
| d4f0a6fecf | |||
| 4db98b2b9f | |||
| cab6fe9482 | |||
| edec39baef | |||
| a7a56f9a26 | |||
| 0551a6cba2 | |||
| 42d1b5e4ba | |||
| 4f78368403 | |||
| 31f25002a6 | |||
| 2b8223bdd5 | |||
| 07e2565a4f | |||
| 761f1e7c1a | |||
| 09da3858ce | |||
| b4c29f34c3 | |||
| d0b02e581d | |||
| 66ad86a755 | |||
| 43f368dfc4 | |||
| e5e1ed2f9c | |||
| 067069d2e2 | |||
| 5b5bc1da56 | |||
| 6e20f9c729 | |||
| 9e148a5cac | |||
| f5bbbdf638 | |||
| 522fde47dc | |||
| 29bac36816 | |||
| 442df9e132 | |||
| 849e389388 | |||
| 20d950d1b3 | |||
| ba6a868a80 | |||
| ce211fd8f5 | |||
| 8a94b72c7d | |||
| f6aa025a01 | |||
| 346583f13e | |||
| abb804f2de | |||
| d380c9494d | |||
| 4e26e325a6 | |||
| c026464375 | |||
| 3610f09c77 | |||
| d38e645492 | |||
| 9c5b879b16 | |||
| 2c41343ce5 | |||
| 7dc73ed6c6 | |||
| 5e04a2f800 |
@@ -0,0 +1,83 @@
|
||||
FROM ubuntu
|
||||
SHELL ["/bin/bash", "-o", "pipefail", "-c"]
|
||||
|
||||
ENV EDITOR=vim
|
||||
|
||||
RUN apt-get update && apt-get upgrade
|
||||
|
||||
RUN apt-get install --yes \
|
||||
ca-certificates \
|
||||
bash-completion \
|
||||
build-essential \
|
||||
curl \
|
||||
cmake \
|
||||
direnv \
|
||||
emacs-nox \
|
||||
gnupg \
|
||||
htop \
|
||||
jq \
|
||||
less \
|
||||
lsb-release \
|
||||
lsof \
|
||||
man-db \
|
||||
nano \
|
||||
neovim \
|
||||
ssl-cert \
|
||||
sudo \
|
||||
unzip \
|
||||
xz-utils \
|
||||
zip
|
||||
|
||||
# configure locales to UTF8
|
||||
RUN apt-get install locales && locale-gen en_US.UTF-8
|
||||
ENV LANG='en_US.UTF-8' LANGUAGE='en_US:en' LC_ALL='en_US.UTF-8'
|
||||
|
||||
# configure direnv
|
||||
RUN direnv hook bash >> $HOME/.bashrc
|
||||
|
||||
# install nix
|
||||
RUN sh <(curl -L https://nixos.org/nix/install) --daemon
|
||||
|
||||
RUN mkdir -p $HOME/.config/nix $HOME/.config/nixpkgs \
|
||||
&& echo 'sandbox = false' >> $HOME/.config/nix/nix.conf \
|
||||
&& echo '{ allowUnfree = true; }' >> $HOME/.config/nixpkgs/config.nix \
|
||||
&& echo '. $HOME/.nix-profile/etc/profile.d/nix.sh' >> $HOME/.bashrc
|
||||
|
||||
|
||||
# install docker and configure daemon to use vfs as GitHub codespaces requires vfs
|
||||
# https://github.com/moby/moby/issues/13742#issuecomment-725197223
|
||||
RUN mkdir -p /etc/apt/keyrings \
|
||||
&& curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg \
|
||||
&& echo \
|
||||
"deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu \
|
||||
$(lsb_release -cs) stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null \
|
||||
&& apt-get update \
|
||||
&& apt-get install --yes docker-ce docker-ce-cli containerd.io docker-compose-plugin \
|
||||
&& mkdir -p /etc/docker \
|
||||
&& echo '{"cgroup-parent":"/actions_job","storage-driver":"vfs"}' >> /etc/docker/daemon.json
|
||||
|
||||
# install golang and language tooling
|
||||
ENV GO_VERSION=1.19
|
||||
ENV GOPATH=$HOME/go-packages
|
||||
ENV GOROOT=$HOME/go
|
||||
ENV PATH=$GOROOT/bin:$GOPATH/bin:$PATH
|
||||
RUN curl -fsSL https://dl.google.com/go/go$GO_VERSION.linux-amd64.tar.gz | tar xzs
|
||||
RUN echo 'export PATH=$GOPATH/bin:$PATH' >> $HOME/.bashrc
|
||||
|
||||
RUN bash -c ". $HOME/.bashrc \
|
||||
go install -v golang.org/x/tools/gopls@latest \
|
||||
&& go install -v mvdan.cc/sh/v3/cmd/shfmt@latest \
|
||||
"
|
||||
|
||||
# install nodejs
|
||||
RUN bash -c "$(curl -fsSL https://deb.nodesource.com/setup_14.x)" \
|
||||
&& apt-get install -y nodejs
|
||||
|
||||
# install zstd
|
||||
RUN bash -c "$(curl -fsSL https://raw.githubusercontent.com/horta/zstd.install/main/install)"
|
||||
|
||||
# install nfpm
|
||||
RUN echo 'deb [trusted=yes] https://repo.goreleaser.com/apt/ /' | sudo tee /etc/apt/sources.list.d/goreleaser.list \
|
||||
&& apt update \
|
||||
&& apt install nfpm
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
// For format details, see https://aka.ms/devcontainer.json
|
||||
{
|
||||
"name": "Development environments on your infrastructure",
|
||||
|
||||
// Sets the run context to one level up instead of the .devcontainer folder.
|
||||
"context": ".",
|
||||
|
||||
// Update the 'dockerFile' property if you aren't using the standard 'Dockerfile' filename.
|
||||
"dockerFile": "Dockerfile",
|
||||
|
||||
// Use 'forwardPorts' to make a list of ports inside the container available locally.
|
||||
// "forwardPorts": [],
|
||||
|
||||
"postStartCommand": "dockerd",
|
||||
|
||||
// privileged is required by GitHub codespaces - https://github.com/microsoft/vscode-dev-containers/issues/727
|
||||
"runArgs": [ "--cap-add=SYS_PTRACE", "--security-opt", "seccomp=unconfined", "--privileged", "--init" ]
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
name: "CLA Assistant"
|
||||
on:
|
||||
issue_comment:
|
||||
types: [created]
|
||||
pull_request_target:
|
||||
types: [opened,closed,synchronize]
|
||||
|
||||
jobs:
|
||||
CLAssistant:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: "CLA Assistant"
|
||||
if: (github.event.comment.body == 'recheck' || github.event.comment.body == 'I have read the CLA Document and I hereby sign the CLA') || github.event_name == 'pull_request_target'
|
||||
uses: contributor-assistant/github-action@v2.2.1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
# the below token should have repo scope and must be manually added by you in the repository's secret
|
||||
PERSONAL_ACCESS_TOKEN : ${{ secrets.CDRCOMMUNITY_GITHUB_TOKEN }}
|
||||
with:
|
||||
remote-organization-name: 'coder'
|
||||
remote-repository-name: 'cla'
|
||||
path-to-signatures: 'v2022-09-04/signatures.json'
|
||||
path-to-document: 'https://github.com/coder/cla/blob/main/README.md'
|
||||
# branch should not be protected
|
||||
branch: 'main'
|
||||
allowlist: dependabot*
|
||||
@@ -4,8 +4,6 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
tags:
|
||||
- "*"
|
||||
|
||||
pull_request:
|
||||
|
||||
@@ -36,7 +34,7 @@ jobs:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
- name: typos-action
|
||||
uses: crate-ci/typos@master
|
||||
uses: crate-ci/typos@v1.12.8
|
||||
with:
|
||||
config: .github/workflows/typos.toml
|
||||
- name: Fix Helper
|
||||
@@ -344,15 +342,6 @@ jobs:
|
||||
fi
|
||||
gotestsum --junitfile="gotests.xml" --packages="./..." -- -parallel=8 -timeout=$test_timeout -short -failfast $COVERAGE_FLAGS
|
||||
|
||||
- name: Upload DataDog Trace
|
||||
if: github.actor != 'dependabot[bot]' && !github.event.pull_request.head.repo.fork
|
||||
env:
|
||||
DATADOG_API_KEY: ${{ secrets.DATADOG_API_KEY }}
|
||||
DD_DATABASE: fake
|
||||
DD_CATEGORY: unit
|
||||
GIT_COMMIT_MESSAGE: ${{ github.event.head_commit.message }}
|
||||
run: go run scripts/datadog-cireport/main.go gotests.xml
|
||||
|
||||
- uses: codecov/codecov-action@v3
|
||||
# This action has a tendency to error out unexpectedly, it has
|
||||
# the `fail_ci_if_error` option that defaults to `false`, but
|
||||
@@ -414,14 +403,6 @@ jobs:
|
||||
- name: Test with PostgreSQL Database
|
||||
run: make test-postgres
|
||||
|
||||
- name: Upload DataDog Trace
|
||||
if: always() && github.actor != 'dependabot[bot]' && !github.event.pull_request.head.repo.fork
|
||||
env:
|
||||
DATADOG_API_KEY: ${{ secrets.DATADOG_API_KEY }}
|
||||
DD_DATABASE: postgresql
|
||||
GIT_COMMIT_MESSAGE: ${{ github.event.head_commit.message }}
|
||||
run: go run scripts/datadog-cireport/main.go gotests.xml
|
||||
|
||||
- uses: codecov/codecov-action@v3
|
||||
# This action has a tendency to error out unexpectedly, it has
|
||||
# the `fail_ci_if_error` option that defaults to `false`, but
|
||||
@@ -506,6 +487,7 @@ jobs:
|
||||
go mod download
|
||||
|
||||
version="$(./scripts/version.sh)"
|
||||
make gen/mark-fresh
|
||||
make -j \
|
||||
build/coder_"$version"_windows_amd64.zip \
|
||||
build/coder_"$version"_linux_amd64.{tar.gz,deb}
|
||||
@@ -548,11 +530,6 @@ jobs:
|
||||
restore-keys: |
|
||||
js-${{ runner.os }}-
|
||||
|
||||
# Go is required for uploading the test results to datadog
|
||||
- uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: "~1.19"
|
||||
|
||||
- uses: actions/setup-node@v3
|
||||
with:
|
||||
node-version: "14"
|
||||
@@ -575,14 +552,6 @@ jobs:
|
||||
files: ./site/coverage/lcov.info
|
||||
flags: unittest-js
|
||||
|
||||
- name: Upload DataDog Trace
|
||||
if: always() && github.actor != 'dependabot[bot]' && !github.event.pull_request.head.repo.fork
|
||||
env:
|
||||
DATADOG_API_KEY: ${{ secrets.DATADOG_API_KEY }}
|
||||
DD_CATEGORY: unit
|
||||
GIT_COMMIT_MESSAGE: ${{ github.event.head_commit.message }}
|
||||
run: go run scripts/datadog-cireport/main.go site/test-results/junit.xml
|
||||
|
||||
test-e2e:
|
||||
name: "test/e2e/${{ matrix.os }}"
|
||||
needs:
|
||||
@@ -606,7 +575,6 @@ jobs:
|
||||
.eslintcache
|
||||
key: js-${{ runner.os }}-e2e-${{ hashFiles('**/yarn.lock') }}
|
||||
|
||||
# Go is required for uploading the test results to datadog
|
||||
- uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: "~1.19"
|
||||
@@ -660,15 +628,8 @@ jobs:
|
||||
with:
|
||||
name: failed-test-videos
|
||||
path: ./site/test-results/**/*.webm
|
||||
retention:days: 7
|
||||
retention-days: 7
|
||||
|
||||
- name: Upload DataDog Trace
|
||||
if: always() && github.actor != 'dependabot[bot]' && runner.os == 'Linux' && !github.event.pull_request.head.repo.fork
|
||||
env:
|
||||
DATADOG_API_KEY: ${{ secrets.DATADOG_API_KEY }}
|
||||
DD_CATEGORY: e2e
|
||||
GIT_COMMIT_MESSAGE: ${{ github.event.head_commit.message }}
|
||||
run: go run scripts/datadog-cireport/main.go site/test-results/junit.xml
|
||||
chromatic:
|
||||
# REMARK: this is only used to build storybook and deploy it to Chromatic.
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
# Dependabot is annoying, but this makes it a bit less so.
|
||||
name: Auto Approve Dependabot
|
||||
|
||||
on: pull_request_target
|
||||
|
||||
jobs:
|
||||
auto-approve:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- uses: hmarr/auto-approve-action@v2
|
||||
if: github.actor == 'dependabot[bot]'
|
||||
@@ -4,8 +4,6 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
tags:
|
||||
- "*"
|
||||
paths:
|
||||
- "dogfood/**"
|
||||
pull_request:
|
||||
@@ -14,12 +12,12 @@ on:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
deploy_image:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Get branch name
|
||||
id: branch-name
|
||||
uses: tj-actions/branch-names@v5.4
|
||||
uses: tj-actions/branch-names@v6.1
|
||||
|
||||
- name: "Branch name to Docker tag name"
|
||||
id: docker-tag-name
|
||||
@@ -49,3 +47,27 @@ jobs:
|
||||
tags: "codercom/oss-dogfood:${{ steps.docker-tag-name.outputs.tag }},codercom/oss-dogfood:latest"
|
||||
cache-from: type=registry,ref=codercom/oss-dogfood:latest
|
||||
cache-to: type=inline
|
||||
deploy_template:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
- name: Get short commit SHA
|
||||
id: vars
|
||||
run: echo "::set-output name=sha_short::$(git rev-parse --short HEAD)"
|
||||
- name: "Install latest Coder"
|
||||
run: |
|
||||
curl -L https://coder.com/install.sh | sh
|
||||
# env:
|
||||
# VERSION: 0.x
|
||||
- name: "Push template"
|
||||
run: |
|
||||
coder templates push $CODER_TEMPLATE_NAME --directory $CODER_TEMPLATE_DIR --yes --name=$CODER_TEMPLATE_VERSION
|
||||
env:
|
||||
# Consumed by Coder CLI
|
||||
CODER_URL: https://dev.coder.com
|
||||
CODER_SESSION_TOKEN: ${{ secrets.CODER_SESSION_TOKEN }}
|
||||
# Template source & details
|
||||
CODER_TEMPLATE_NAME: ${{ secrets.CODER_TEMPLATE_NAME }}
|
||||
CODER_TEMPLATE_VERSION: ${{ steps.vars.outputs.sha_short }}
|
||||
CODER_TEMPLATE_DIR: ./dogfood
|
||||
|
||||
@@ -13,7 +13,7 @@ jobs:
|
||||
steps:
|
||||
# v5.1.0 has a weird bug that makes stalebot add then remove its own label
|
||||
# https://github.com/actions/stale/pull/775
|
||||
- uses: actions/stale@v5.0.0
|
||||
- uses: actions/stale@v6.0.0
|
||||
with:
|
||||
stale-issue-label: stale
|
||||
stale-pr-label: stale
|
||||
|
||||
@@ -5,6 +5,8 @@ IST = "IST"
|
||||
MacOS = "macOS"
|
||||
|
||||
[default.extend-words]
|
||||
# do as sudo replacement
|
||||
doas = "doas"
|
||||
|
||||
[files]
|
||||
extend-exclude = [
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
name: Welcome
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened]
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- uses: wow-actions/welcome@v1
|
||||
with:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
FIRST_PR_REACTIONS: '+1, hooray, rocket, heart'
|
||||
FIRST_PR_COMMENT: |
|
||||
👋 Welcome @{{ author }} to Coder! Yo @coder/docs this is @{{ author }}'s first pull-request here!
|
||||
FIRST_PR_MERGED: |
|
||||
🎉 Thanks for the contribution @{{ author }}! Yo @coder/docs @{{ author }}'s first contribution has been merged! 👀👀👀
|
||||
+4
-2
@@ -15,6 +15,7 @@ vendor
|
||||
yarn-error.log
|
||||
gotests.coverage
|
||||
.idea
|
||||
.gitpod.yml
|
||||
.DS_Store
|
||||
|
||||
# Front-end ignore
|
||||
@@ -30,8 +31,8 @@ site/**/*.typegen.ts
|
||||
site/build-storybook.log
|
||||
|
||||
# Build
|
||||
build/
|
||||
dist/
|
||||
/build/
|
||||
/dist/
|
||||
site/out/
|
||||
|
||||
*.tfstate
|
||||
@@ -41,6 +42,7 @@ site/out/
|
||||
.terraform/
|
||||
|
||||
.vscode/*.log
|
||||
.vscode/launch.json
|
||||
**/*.swp
|
||||
.coderv2/*
|
||||
**/__debug_bin
|
||||
|
||||
Vendored
+28
-10
@@ -8,7 +8,9 @@
|
||||
"circbuf",
|
||||
"cliflag",
|
||||
"cliui",
|
||||
"codecov",
|
||||
"coderd",
|
||||
"coderdenttest",
|
||||
"coderdtest",
|
||||
"codersdk",
|
||||
"cronstrue",
|
||||
@@ -17,11 +19,14 @@
|
||||
"derphttp",
|
||||
"derpmap",
|
||||
"devel",
|
||||
"dflags",
|
||||
"drpc",
|
||||
"drpcconn",
|
||||
"drpcmux",
|
||||
"drpcserver",
|
||||
"Dsts",
|
||||
"enablements",
|
||||
"eventsourcemock",
|
||||
"fatih",
|
||||
"Formik",
|
||||
"gitsshkey",
|
||||
@@ -31,6 +36,7 @@
|
||||
"gonet",
|
||||
"gossh",
|
||||
"gsyslog",
|
||||
"GTTY",
|
||||
"hashicorp",
|
||||
"hclsyntax",
|
||||
"httpapi",
|
||||
@@ -67,6 +73,7 @@
|
||||
"ntqry",
|
||||
"OIDC",
|
||||
"oneof",
|
||||
"opty",
|
||||
"paralleltest",
|
||||
"parameterscopeid",
|
||||
"pqtype",
|
||||
@@ -76,10 +83,14 @@
|
||||
"provisionerd",
|
||||
"provisionersdk",
|
||||
"ptty",
|
||||
"ptys",
|
||||
"ptytest",
|
||||
"quickstart",
|
||||
"reconfig",
|
||||
"replicasync",
|
||||
"retrier",
|
||||
"rpty",
|
||||
"SCIM",
|
||||
"sdkproto",
|
||||
"sdktrace",
|
||||
"Signup",
|
||||
@@ -87,6 +98,7 @@
|
||||
"sourcemapped",
|
||||
"Srcs",
|
||||
"stretchr",
|
||||
"STTY",
|
||||
"stuntest",
|
||||
"tailbroker",
|
||||
"tailcfg",
|
||||
@@ -105,6 +117,7 @@
|
||||
"tfjson",
|
||||
"tfplan",
|
||||
"tfstate",
|
||||
"tios",
|
||||
"tparallel",
|
||||
"trimprefix",
|
||||
"tsdial",
|
||||
@@ -112,10 +125,12 @@
|
||||
"tstun",
|
||||
"turnconn",
|
||||
"typegen",
|
||||
"typesafe",
|
||||
"unconvert",
|
||||
"Untar",
|
||||
"Userspace",
|
||||
"VMID",
|
||||
"walkthrough",
|
||||
"weblinks",
|
||||
"webrtc",
|
||||
"wgcfg",
|
||||
@@ -135,6 +150,10 @@
|
||||
"xstate",
|
||||
"yamux"
|
||||
],
|
||||
"cSpell.ignorePaths": [
|
||||
"site/package.json",
|
||||
".vscode/settings.json"
|
||||
],
|
||||
"emeraldwalk.runonsave": {
|
||||
"commands": [
|
||||
{
|
||||
@@ -157,20 +176,19 @@
|
||||
"go.lintFlags": ["--fast"],
|
||||
"go.lintOnSave": "package",
|
||||
"go.coverOnSave": true,
|
||||
"go.coverageDecorator": {
|
||||
"type": "gutter",
|
||||
"coveredGutterStyle": "blockgreen",
|
||||
"uncoveredGutterStyle": "blockred"
|
||||
},
|
||||
// The codersdk is used by coderd another other packages extensively.
|
||||
// To reduce redundancy in tests, it's covered by other packages.
|
||||
// Since package coverage pairing can't be defined, all packages cover
|
||||
// all other packages.
|
||||
"go.testFlags": ["-short", "-coverpkg=./..."],
|
||||
"go.coverageDecorator": {
|
||||
"type": "gutter",
|
||||
"coveredHighlightColor": "rgba(64,128,128,0.5)",
|
||||
"uncoveredHighlightColor": "rgba(128,64,64,0.25)",
|
||||
"coveredBorderColor": "rgba(64,128,128,0.5)",
|
||||
"uncoveredBorderColor": "rgba(128,64,64,0.25)",
|
||||
"coveredGutterStyle": "blockgreen",
|
||||
"uncoveredGutterStyle": "blockred"
|
||||
},
|
||||
"go.testFlags": [
|
||||
"-short",
|
||||
"-coverpkg=./..."
|
||||
],
|
||||
// We often use a version of TypeScript that's ahead of the version shipped
|
||||
// with VS Code.
|
||||
"typescript.tsdk": "./site/node_modules/typescript/lib"
|
||||
|
||||
@@ -25,6 +25,7 @@ COPY --chown=coder:coder --chmod=700 empty-dir /home/coder
|
||||
|
||||
USER coder:coder
|
||||
ENV HOME=/home/coder
|
||||
ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/opt
|
||||
WORKDIR /home/coder
|
||||
|
||||
ENTRYPOINT [ "/opt/coder", "server" ]
|
||||
|
||||
@@ -37,6 +37,13 @@ GOARCH := $(shell go env GOARCH)
|
||||
GOOS_BIN_EXT := $(if $(filter windows, $(GOOS)),.exe,)
|
||||
VERSION := $(shell ./scripts/version.sh)
|
||||
|
||||
# Use the highest ZSTD compression level in CI.
|
||||
ifdef CI
|
||||
ZSTDFLAGS := -22 --ultra
|
||||
else
|
||||
ZSTDFLAGS := -6
|
||||
endif
|
||||
|
||||
# All ${OS}_${ARCH} combos we build for. Windows binaries have the .exe suffix.
|
||||
OS_ARCHES := \
|
||||
linux_amd64 linux_arm64 linux_armv7 \
|
||||
@@ -102,9 +109,8 @@ build/coder-slim_$(VERSION).tar: build/coder-slim_$(VERSION)_checksums.sha1 $(CO
|
||||
popd
|
||||
|
||||
build/coder-slim_$(VERSION).tar.zst site/out/bin/coder.tar.zst: build/coder-slim_$(VERSION).tar
|
||||
zstd -6 \
|
||||
zstd $(ZSTDFLAGS) \
|
||||
--force \
|
||||
--ultra \
|
||||
--long \
|
||||
--no-progress \
|
||||
-o "build/coder-slim_$(VERSION).tar.zst" \
|
||||
@@ -323,7 +329,6 @@ build/coder_helm_$(VERSION).tgz:
|
||||
site/out/index.html: $(shell find ./site -not -path './site/node_modules/*' -type f -name '*.tsx') $(shell find ./site -not -path './site/node_modules/*' -type f -name '*.ts') site/package.json
|
||||
./scripts/yarn_install.sh
|
||||
cd site
|
||||
yarn typegen
|
||||
yarn build
|
||||
|
||||
install: build/coder_$(VERSION)_$(GOOS)_$(GOARCH)$(GOOS_BIN_EXT)
|
||||
@@ -380,7 +385,6 @@ lint/shellcheck: $(shell shfmt -f .)
|
||||
gen: \
|
||||
coderd/database/dump.sql \
|
||||
coderd/database/querier.go \
|
||||
peerbroker/proto/peerbroker.pb.go \
|
||||
provisionersdk/proto/provisioner.pb.go \
|
||||
provisionerd/proto/provisionerd.pb.go \
|
||||
site/src/api/typesGenerated.ts
|
||||
@@ -389,7 +393,7 @@ gen: \
|
||||
# Mark all generated files as fresh so make thinks they're up-to-date. This is
|
||||
# used during releases so we don't run generation scripts.
|
||||
gen/mark-fresh:
|
||||
files="coderd/database/dump.sql coderd/database/querier.go peerbroker/proto/peerbroker.pb.go provisionersdk/proto/provisioner.pb.go provisionerd/proto/provisionerd.pb.go site/src/api/typesGenerated.ts"
|
||||
files="coderd/database/dump.sql coderd/database/querier.go provisionersdk/proto/provisioner.pb.go provisionerd/proto/provisionerd.pb.go site/src/api/typesGenerated.ts"
|
||||
for file in $$files; do
|
||||
echo "$$file"
|
||||
if [ ! -f "$$file" ]; then
|
||||
@@ -405,20 +409,12 @@ gen/mark-fresh:
|
||||
# Runs migrations to output a dump of the database schema after migrations are
|
||||
# applied.
|
||||
coderd/database/dump.sql: coderd/database/gen/dump/main.go $(wildcard coderd/database/migrations/*.sql)
|
||||
go run coderd/database/gen/dump/main.go
|
||||
go run ./coderd/database/gen/dump/main.go
|
||||
|
||||
# Generates Go code for querying the database.
|
||||
coderd/database/querier.go: coderd/database/sqlc.yaml coderd/database/dump.sql $(wildcard coderd/database/queries/*.sql) coderd/database/gen/enum/main.go
|
||||
./coderd/database/generate.sh
|
||||
|
||||
peerbroker/proto/peerbroker.pb.go: peerbroker/proto/peerbroker.proto
|
||||
protoc \
|
||||
--go_out=. \
|
||||
--go_opt=paths=source_relative \
|
||||
--go-drpc_out=. \
|
||||
--go-drpc_opt=paths=source_relative \
|
||||
./peerbroker/proto/peerbroker.proto
|
||||
|
||||
provisionersdk/proto/provisioner.pb.go: provisionersdk/proto/provisioner.proto
|
||||
protoc \
|
||||
--go_out=. \
|
||||
|
||||
@@ -52,13 +52,13 @@ You can modify the installation process by including flags. Run the help command
|
||||
curl -L https://coder.com/install.sh | sh -s -- --help
|
||||
```
|
||||
|
||||
> See [install](docs/install.md) for additional methods.
|
||||
> See [install](docs/install) for additional methods.
|
||||
|
||||
Once installed, you can start a production deployment<sup>1</sup> with a single command:
|
||||
|
||||
```sh
|
||||
# Automatically sets up an external access URL on *.try.coder.app
|
||||
coder server --tunnel
|
||||
coder server
|
||||
|
||||
# Requires a PostgreSQL instance and external access URL
|
||||
coder server --postgres-url <url> --access-url <url>
|
||||
|
||||
+119
-260
@@ -10,8 +10,8 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
@@ -34,8 +34,7 @@ import (
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/agent/usershell"
|
||||
"github.com/coder/coder/peer"
|
||||
"github.com/coder/coder/peerbroker"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/pty"
|
||||
"github.com/coder/coder/tailnet"
|
||||
"github.com/coder/retry"
|
||||
@@ -52,42 +51,23 @@ const (
|
||||
MagicSessionErrorCode = 229
|
||||
)
|
||||
|
||||
var (
|
||||
// tailnetIP is a static IPv6 address with the Tailscale prefix that is used to route
|
||||
// connections from clients to this node. A dynamic address is not required because a Tailnet
|
||||
// client only dials a single agent at a time.
|
||||
tailnetIP = netip.MustParseAddr("fd7a:115c:a1e0:49d6:b259:b7ac:b1b2:48f4")
|
||||
tailnetSSHPort = 1
|
||||
tailnetReconnectingPTYPort = 2
|
||||
tailnetSpeedtestPort = 3
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
CoordinatorDialer CoordinatorDialer
|
||||
WebRTCDialer WebRTCDialer
|
||||
FetchMetadata FetchMetadata
|
||||
|
||||
StatsReporter StatsReporter
|
||||
ReconnectingPTYTimeout time.Duration
|
||||
EnvironmentVariables map[string]string
|
||||
Logger slog.Logger
|
||||
CoordinatorDialer CoordinatorDialer
|
||||
FetchMetadata FetchMetadata
|
||||
StatsReporter StatsReporter
|
||||
WorkspaceAgentApps WorkspaceAgentApps
|
||||
PostWorkspaceAgentAppHealth PostWorkspaceAgentAppHealth
|
||||
ReconnectingPTYTimeout time.Duration
|
||||
EnvironmentVariables map[string]string
|
||||
Logger slog.Logger
|
||||
}
|
||||
|
||||
type Metadata struct {
|
||||
DERPMap *tailcfg.DERPMap `json:"derpmap"`
|
||||
EnvironmentVariables map[string]string `json:"environment_variables"`
|
||||
StartupScript string `json:"startup_script"`
|
||||
Directory string `json:"directory"`
|
||||
}
|
||||
|
||||
type WebRTCDialer func(ctx context.Context, logger slog.Logger) (*peerbroker.Listener, error)
|
||||
|
||||
// CoordinatorDialer is a function that constructs a new broker.
|
||||
// A dialer must be passed in to allow for reconnects.
|
||||
type CoordinatorDialer func(ctx context.Context) (net.Conn, error)
|
||||
type CoordinatorDialer func(context.Context) (net.Conn, error)
|
||||
|
||||
// FetchMetadata is a function to obtain metadata for the agent.
|
||||
type FetchMetadata func(ctx context.Context) (Metadata, error)
|
||||
type FetchMetadata func(context.Context) (codersdk.WorkspaceAgentMetadata, error)
|
||||
|
||||
func New(options Options) io.Closer {
|
||||
if options.ReconnectingPTYTimeout == 0 {
|
||||
@@ -95,24 +75,24 @@ func New(options Options) io.Closer {
|
||||
}
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
server := &agent{
|
||||
webrtcDialer: options.WebRTCDialer,
|
||||
reconnectingPTYTimeout: options.ReconnectingPTYTimeout,
|
||||
logger: options.Logger,
|
||||
closeCancel: cancelFunc,
|
||||
closed: make(chan struct{}),
|
||||
envVars: options.EnvironmentVariables,
|
||||
coordinatorDialer: options.CoordinatorDialer,
|
||||
fetchMetadata: options.FetchMetadata,
|
||||
stats: &Stats{},
|
||||
statsReporter: options.StatsReporter,
|
||||
reconnectingPTYTimeout: options.ReconnectingPTYTimeout,
|
||||
logger: options.Logger,
|
||||
closeCancel: cancelFunc,
|
||||
closed: make(chan struct{}),
|
||||
envVars: options.EnvironmentVariables,
|
||||
coordinatorDialer: options.CoordinatorDialer,
|
||||
fetchMetadata: options.FetchMetadata,
|
||||
stats: &Stats{},
|
||||
statsReporter: options.StatsReporter,
|
||||
workspaceAgentApps: options.WorkspaceAgentApps,
|
||||
postWorkspaceAgentAppHealth: options.PostWorkspaceAgentAppHealth,
|
||||
}
|
||||
server.init(ctx)
|
||||
return server
|
||||
}
|
||||
|
||||
type agent struct {
|
||||
webrtcDialer WebRTCDialer
|
||||
logger slog.Logger
|
||||
logger slog.Logger
|
||||
|
||||
reconnectingPTYs sync.Map
|
||||
reconnectingPTYTimeout time.Duration
|
||||
@@ -128,14 +108,16 @@ type agent struct {
|
||||
fetchMetadata FetchMetadata
|
||||
sshServer *ssh.Server
|
||||
|
||||
network *tailnet.Conn
|
||||
coordinatorDialer CoordinatorDialer
|
||||
stats *Stats
|
||||
statsReporter StatsReporter
|
||||
network *tailnet.Conn
|
||||
coordinatorDialer CoordinatorDialer
|
||||
stats *Stats
|
||||
statsReporter StatsReporter
|
||||
workspaceAgentApps WorkspaceAgentApps
|
||||
postWorkspaceAgentAppHealth PostWorkspaceAgentAppHealth
|
||||
}
|
||||
|
||||
func (a *agent) run(ctx context.Context) {
|
||||
var metadata Metadata
|
||||
var metadata codersdk.WorkspaceAgentMetadata
|
||||
var err error
|
||||
// An exponential back-off occurs when the connection is failing to dial.
|
||||
// This is to prevent server spam in case of a coderd outage.
|
||||
@@ -173,12 +155,13 @@ func (a *agent) run(ctx context.Context) {
|
||||
}
|
||||
}()
|
||||
|
||||
if a.webrtcDialer != nil {
|
||||
go a.runWebRTCNetworking(ctx)
|
||||
}
|
||||
if metadata.DERPMap != nil {
|
||||
go a.runTailnet(ctx, metadata.DERPMap)
|
||||
}
|
||||
|
||||
if a.workspaceAgentApps != nil && a.postWorkspaceAgentAppHealth != nil {
|
||||
go NewWorkspaceAppHealthReporter(a.logger, a.workspaceAgentApps, a.postWorkspaceAgentAppHealth)(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *agent) runTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) {
|
||||
@@ -187,13 +170,14 @@ func (a *agent) runTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) {
|
||||
if a.isClosed() {
|
||||
return
|
||||
}
|
||||
a.logger.Debug(ctx, "running tailnet with derpmap", slog.F("derpmap", derpMap))
|
||||
if a.network != nil {
|
||||
a.network.SetDERPMap(derpMap)
|
||||
return
|
||||
}
|
||||
var err error
|
||||
a.network, err = tailnet.NewConn(&tailnet.Options{
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(tailnetIP, 128)},
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.TailnetIP, 128)},
|
||||
DERPMap: derpMap,
|
||||
Logger: a.logger.Named("tailnet"),
|
||||
})
|
||||
@@ -210,7 +194,7 @@ func (a *agent) runTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) {
|
||||
})
|
||||
go a.runCoordinator(ctx)
|
||||
|
||||
sshListener, err := a.network.Listen("tcp", ":"+strconv.Itoa(tailnetSSHPort))
|
||||
sshListener, err := a.network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetSSHPort))
|
||||
if err != nil {
|
||||
a.logger.Critical(ctx, "listen for ssh", slog.Error(err))
|
||||
return
|
||||
@@ -224,7 +208,8 @@ func (a *agent) runTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) {
|
||||
go a.sshServer.HandleConn(a.stats.wrapConn(conn))
|
||||
}
|
||||
}()
|
||||
reconnectingPTYListener, err := a.network.Listen("tcp", ":"+strconv.Itoa(tailnetReconnectingPTYPort))
|
||||
|
||||
reconnectingPTYListener, err := a.network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetReconnectingPTYPort))
|
||||
if err != nil {
|
||||
a.logger.Critical(ctx, "listen for reconnecting pty", slog.Error(err))
|
||||
return
|
||||
@@ -250,7 +235,7 @@ func (a *agent) runTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var msg reconnectingPTYInit
|
||||
var msg codersdk.ReconnectingPTYInit
|
||||
err = json.Unmarshal(data, &msg)
|
||||
if err != nil {
|
||||
continue
|
||||
@@ -258,7 +243,8 @@ func (a *agent) runTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) {
|
||||
go a.handleReconnectingPTY(ctx, msg, conn)
|
||||
}
|
||||
}()
|
||||
speedtestListener, err := a.network.Listen("tcp", ":"+strconv.Itoa(tailnetSpeedtestPort))
|
||||
|
||||
speedtestListener, err := a.network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetSpeedtestPort))
|
||||
if err != nil {
|
||||
a.logger.Critical(ctx, "listen for speedtest", slog.Error(err))
|
||||
return
|
||||
@@ -279,10 +265,44 @@ func (a *agent) runTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) {
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
statisticsListener, err := a.network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetStatisticsPort))
|
||||
if err != nil {
|
||||
a.logger.Critical(ctx, "listen for statistics", slog.Error(err))
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
defer statisticsListener.Close()
|
||||
server := &http.Server{
|
||||
Handler: a.statisticsHandler(),
|
||||
ReadTimeout: 20 * time.Second,
|
||||
ReadHeaderTimeout: 20 * time.Second,
|
||||
WriteTimeout: 20 * time.Second,
|
||||
ErrorLog: slog.Stdlib(ctx, a.logger.Named("statistics_http_server"), slog.LevelInfo),
|
||||
}
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
_ = server.Close()
|
||||
}()
|
||||
|
||||
err = server.Serve(statisticsListener)
|
||||
if err != nil && !xerrors.Is(err, http.ErrServerClosed) && !strings.Contains(err.Error(), "use of closed network connection") {
|
||||
a.logger.Critical(ctx, "serve statistics HTTP server", slog.Error(err))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// runCoordinator listens for nodes and updates the self-node as it changes.
|
||||
func (a *agent) runCoordinator(ctx context.Context) {
|
||||
for {
|
||||
reconnect := a.runCoordinatorWithRetry(ctx)
|
||||
if !reconnect {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *agent) runCoordinatorWithRetry(ctx context.Context) (reconnect bool) {
|
||||
var coordinator net.Conn
|
||||
var err error
|
||||
// An exponential back-off occurs when the connection is failing to dial.
|
||||
@@ -291,81 +311,38 @@ func (a *agent) runCoordinator(ctx context.Context) {
|
||||
coordinator, err = a.coordinatorDialer(ctx)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return
|
||||
return false
|
||||
}
|
||||
if a.isClosed() {
|
||||
return
|
||||
return false
|
||||
}
|
||||
a.logger.Warn(context.Background(), "failed to dial", slog.Error(err))
|
||||
continue
|
||||
}
|
||||
//nolint:revive // Defer is ok because we're exiting this loop.
|
||||
defer coordinator.Close()
|
||||
a.logger.Info(context.Background(), "connected to coordination server")
|
||||
break
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
return false
|
||||
default:
|
||||
}
|
||||
defer coordinator.Close()
|
||||
sendNodes, errChan := tailnet.ServeCoordinator(coordinator, a.network.UpdateNodes)
|
||||
a.network.SetNodeCallback(sendNodes)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
return false
|
||||
case err := <-errChan:
|
||||
if a.isClosed() {
|
||||
return
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return
|
||||
return false
|
||||
}
|
||||
a.logger.Debug(ctx, "node broker accept exited; restarting connection", slog.Error(err))
|
||||
a.runCoordinator(ctx)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (a *agent) runWebRTCNetworking(ctx context.Context) {
|
||||
var peerListener *peerbroker.Listener
|
||||
var err error
|
||||
// An exponential back-off occurs when the connection is failing to dial.
|
||||
// This is to prevent server spam in case of a coderd outage.
|
||||
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
|
||||
peerListener, err = a.webrtcDialer(ctx, a.logger)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return
|
||||
}
|
||||
if a.isClosed() {
|
||||
return
|
||||
}
|
||||
a.logger.Warn(context.Background(), "failed to dial", slog.Error(err))
|
||||
continue
|
||||
}
|
||||
a.logger.Info(context.Background(), "connected to webrtc broker")
|
||||
break
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
for {
|
||||
conn, err := peerListener.Accept()
|
||||
if err != nil {
|
||||
if a.isClosed() {
|
||||
return
|
||||
}
|
||||
a.logger.Debug(ctx, "peer listener accept exited; restarting connection", slog.Error(err))
|
||||
a.runWebRTCNetworking(ctx)
|
||||
return
|
||||
}
|
||||
a.closeMutex.Lock()
|
||||
a.connCloseWait.Add(1)
|
||||
a.closeMutex.Unlock()
|
||||
go a.handlePeerConn(ctx, conn)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -374,7 +351,7 @@ func (a *agent) runStartupScript(ctx context.Context, script string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
writer, err := os.OpenFile(filepath.Join(os.TempDir(), "coder-startup-script.log"), os.O_CREATE|os.O_RDWR, 0600)
|
||||
writer, err := os.OpenFile(filepath.Join(os.TempDir(), "coder-startup-script.log"), os.O_CREATE|os.O_RDWR, 0o600)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("open startup script log file: %w", err)
|
||||
}
|
||||
@@ -401,74 +378,6 @@ func (a *agent) runStartupScript(ctx context.Context, script string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *agent) handlePeerConn(ctx context.Context, peerConn *peer.Conn) {
|
||||
go func() {
|
||||
select {
|
||||
case <-a.closed:
|
||||
case <-peerConn.Closed():
|
||||
}
|
||||
_ = peerConn.Close()
|
||||
a.connCloseWait.Done()
|
||||
}()
|
||||
for {
|
||||
channel, err := peerConn.Accept(ctx)
|
||||
if err != nil {
|
||||
if errors.Is(err, peer.ErrClosed) || a.isClosed() {
|
||||
return
|
||||
}
|
||||
a.logger.Debug(ctx, "accept channel from peer connection", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
conn := channel.NetConn()
|
||||
|
||||
switch channel.Protocol() {
|
||||
case ProtocolSSH:
|
||||
go a.sshServer.HandleConn(a.stats.wrapConn(conn))
|
||||
case ProtocolReconnectingPTY:
|
||||
rawID := channel.Label()
|
||||
// The ID format is referenced in conn.go.
|
||||
// <uuid>:<height>:<width>
|
||||
idParts := strings.SplitN(rawID, ":", 4)
|
||||
if len(idParts) != 4 {
|
||||
a.logger.Warn(ctx, "client sent invalid id format", slog.F("raw-id", rawID))
|
||||
continue
|
||||
}
|
||||
id := idParts[0]
|
||||
// Enforce a consistent format for IDs.
|
||||
_, err := uuid.Parse(id)
|
||||
if err != nil {
|
||||
a.logger.Warn(ctx, "client sent reconnection token that isn't a uuid", slog.F("id", id), slog.Error(err))
|
||||
continue
|
||||
}
|
||||
// Parse the initial terminal dimensions.
|
||||
height, err := strconv.Atoi(idParts[1])
|
||||
if err != nil {
|
||||
a.logger.Warn(ctx, "client sent invalid height", slog.F("id", id), slog.F("height", idParts[1]))
|
||||
continue
|
||||
}
|
||||
width, err := strconv.Atoi(idParts[2])
|
||||
if err != nil {
|
||||
a.logger.Warn(ctx, "client sent invalid width", slog.F("id", id), slog.F("width", idParts[2]))
|
||||
continue
|
||||
}
|
||||
go a.handleReconnectingPTY(ctx, reconnectingPTYInit{
|
||||
ID: id,
|
||||
Height: uint16(height),
|
||||
Width: uint16(width),
|
||||
Command: idParts[3],
|
||||
}, a.stats.wrapConn(conn))
|
||||
case ProtocolDial:
|
||||
go a.handleDial(ctx, channel.Label(), a.stats.wrapConn(conn))
|
||||
default:
|
||||
a.logger.Warn(ctx, "unhandled protocol from channel",
|
||||
slog.F("protocol", channel.Protocol()),
|
||||
slog.F("label", channel.Label()),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *agent) init(ctx context.Context) {
|
||||
a.logger.Info(ctx, "generating host key")
|
||||
// Clients' should ignore the host key when connecting.
|
||||
@@ -537,6 +446,8 @@ func (a *agent) init(ctx context.Context) {
|
||||
},
|
||||
SubsystemHandlers: map[string]ssh.SubsystemHandler{
|
||||
"sftp": func(session ssh.Session) {
|
||||
session.DisablePTYEmulation()
|
||||
|
||||
server, err := sftp.NewServer(session)
|
||||
if err != nil {
|
||||
a.logger.Debug(session.Context(), "initialize sftp server", slog.Error(err))
|
||||
@@ -554,7 +465,7 @@ func (a *agent) init(ctx context.Context) {
|
||||
|
||||
go a.run(ctx)
|
||||
if a.statsReporter != nil {
|
||||
cl, err := a.statsReporter(ctx, a.logger, func() *Stats {
|
||||
cl, err := a.statsReporter(ctx, a.logger, func() *codersdk.AgentStats {
|
||||
return a.stats.Copy()
|
||||
})
|
||||
if err != nil {
|
||||
@@ -589,7 +500,7 @@ func (a *agent) createCommand(ctx context.Context, rawCommand string, env []stri
|
||||
if rawMetadata == nil {
|
||||
return nil, xerrors.Errorf("no metadata was provided: %w", err)
|
||||
}
|
||||
metadata, valid := rawMetadata.(Metadata)
|
||||
metadata, valid := rawMetadata.(codersdk.WorkspaceAgentMetadata)
|
||||
if !valid {
|
||||
return nil, xerrors.Errorf("metadata is the wrong type: %T", metadata)
|
||||
}
|
||||
@@ -661,7 +572,8 @@ func (a *agent) createCommand(ctx context.Context, rawCommand string, env []stri
|
||||
}
|
||||
|
||||
func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
|
||||
cmd, err := a.createCommand(session.Context(), session.RawCommand(), session.Environ())
|
||||
ctx := session.Context()
|
||||
cmd, err := a.createCommand(ctx, session.RawCommand(), session.Environ())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -678,32 +590,34 @@ func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
|
||||
|
||||
sshPty, windowSize, isPty := session.Pty()
|
||||
if isPty {
|
||||
// Disable minimal PTY emulation set by gliderlabs/ssh (NL-to-CRNL).
|
||||
// See https://github.com/coder/coder/issues/3371.
|
||||
session.DisablePTYEmulation()
|
||||
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term))
|
||||
|
||||
// The pty package sets `SSH_TTY` on supported platforms.
|
||||
ptty, process, err := pty.Start(cmd)
|
||||
ptty, process, err := pty.Start(cmd, pty.WithPTYOption(
|
||||
pty.WithSSHRequest(sshPty),
|
||||
pty.WithLogger(slog.Stdlib(ctx, a.logger, slog.LevelInfo)),
|
||||
))
|
||||
if err != nil {
|
||||
return xerrors.Errorf("start command: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
closeErr := ptty.Close()
|
||||
if closeErr != nil {
|
||||
a.logger.Warn(context.Background(), "failed to close tty",
|
||||
slog.Error(closeErr))
|
||||
a.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr))
|
||||
if retErr == nil {
|
||||
retErr = closeErr
|
||||
}
|
||||
}
|
||||
}()
|
||||
err = ptty.Resize(uint16(sshPty.Window.Height), uint16(sshPty.Window.Width))
|
||||
if err != nil {
|
||||
return xerrors.Errorf("resize ptty: %w", err)
|
||||
}
|
||||
go func() {
|
||||
for win := range windowSize {
|
||||
resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width))
|
||||
if resizeErr != nil {
|
||||
a.logger.Warn(context.Background(), "failed to resize tty", slog.Error(resizeErr))
|
||||
a.logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr))
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -718,8 +632,7 @@ func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
|
||||
// ExitErrors just mean the command we run returned a non-zero exit code, which is normal
|
||||
// and not something to be concerned about. But, if it's something else, we should log it.
|
||||
if err != nil && !xerrors.As(err, &exitErr) {
|
||||
a.logger.Warn(context.Background(), "wait error",
|
||||
slog.Error(err))
|
||||
a.logger.Warn(ctx, "wait error", slog.Error(err))
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -743,7 +656,7 @@ func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
|
||||
return cmd.Wait()
|
||||
}
|
||||
|
||||
func (a *agent) handleReconnectingPTY(ctx context.Context, msg reconnectingPTYInit, conn net.Conn) {
|
||||
func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.ReconnectingPTYInit, conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
var rpty *reconnectingPTY
|
||||
@@ -884,7 +797,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg reconnectingPTYIn
|
||||
rpty.activeConnsMutex.Unlock()
|
||||
}()
|
||||
decoder := json.NewDecoder(conn)
|
||||
var req ReconnectingPTYRequest
|
||||
var req codersdk.ReconnectingPTYRequest
|
||||
for {
|
||||
err = decoder.Decode(&req)
|
||||
if xerrors.Is(err, io.EOF) {
|
||||
@@ -911,70 +824,6 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg reconnectingPTYIn
|
||||
}
|
||||
}
|
||||
|
||||
// dialResponse is written to datachannels with protocol "dial" by the agent as
|
||||
// the first packet to signify whether the dial succeeded or failed.
|
||||
type dialResponse struct {
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (a *agent) handleDial(ctx context.Context, label string, conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
writeError := func(responseError error) error {
|
||||
msg := ""
|
||||
if responseError != nil {
|
||||
msg = responseError.Error()
|
||||
if !xerrors.Is(responseError, io.EOF) {
|
||||
a.logger.Warn(ctx, "handle dial", slog.F("label", label), slog.Error(responseError))
|
||||
}
|
||||
}
|
||||
b, err := json.Marshal(dialResponse{
|
||||
Error: msg,
|
||||
})
|
||||
if err != nil {
|
||||
a.logger.Warn(ctx, "write dial response", slog.F("label", label), slog.Error(err))
|
||||
return xerrors.Errorf("marshal agent webrtc dial response: %w", err)
|
||||
}
|
||||
|
||||
_, err = conn.Write(b)
|
||||
return err
|
||||
}
|
||||
|
||||
u, err := url.Parse(label)
|
||||
if err != nil {
|
||||
_ = writeError(xerrors.Errorf("parse URL %q: %w", label, err))
|
||||
return
|
||||
}
|
||||
|
||||
network := u.Scheme
|
||||
addr := u.Host + u.Path
|
||||
if strings.HasPrefix(network, "unix") {
|
||||
if runtime.GOOS == "windows" {
|
||||
_ = writeError(xerrors.New("Unix forwarding is not supported from Windows workspaces"))
|
||||
return
|
||||
}
|
||||
addr, err = ExpandRelativeHomePath(addr)
|
||||
if err != nil {
|
||||
_ = writeError(xerrors.Errorf("expand path %q: %w", addr, err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
d := net.Dialer{Timeout: 3 * time.Second}
|
||||
nconn, err := d.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
_ = writeError(xerrors.Errorf("dial '%v://%v': %w", network, addr, err))
|
||||
return
|
||||
}
|
||||
|
||||
err = writeError(nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
Bicopy(ctx, conn, nconn)
|
||||
}
|
||||
|
||||
// isClosed returns whether the API is closed or not.
|
||||
func (a *agent) isClosed() bool {
|
||||
select {
|
||||
@@ -1030,12 +879,22 @@ func (r *reconnectingPTY) Close() {
|
||||
// after one or both of them are done writing. If the context is canceled, both
|
||||
// of the connections will be closed.
|
||||
func Bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) {
|
||||
defer c1.Close()
|
||||
defer c2.Close()
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
defer func() {
|
||||
_ = c1.Close()
|
||||
_ = c2.Close()
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
copyFunc := func(dst io.WriteCloser, src io.Reader) {
|
||||
defer wg.Done()
|
||||
defer func() {
|
||||
wg.Done()
|
||||
// If one side of the copy fails, ensure the other one exits as
|
||||
// well.
|
||||
cancel()
|
||||
}()
|
||||
_, _ = io.Copy(dst, src)
|
||||
}
|
||||
|
||||
|
||||
+92
-179
@@ -20,12 +20,10 @@ import (
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
"tailscale.com/net/speedtest"
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
scp "github.com/bramvdbogaerde/go-scp"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pion/udp"
|
||||
"github.com/pion/webrtc/v3"
|
||||
"github.com/pkg/sftp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -37,10 +35,7 @@ import (
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/agent"
|
||||
"github.com/coder/coder/peer"
|
||||
"github.com/coder/coder/peerbroker"
|
||||
"github.com/coder/coder/peerbroker/proto"
|
||||
"github.com/coder/coder/provisionersdk"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/pty/ptytest"
|
||||
"github.com/coder/coder/tailnet"
|
||||
"github.com/coder/coder/tailnet/tailnettest"
|
||||
@@ -54,69 +49,54 @@ func TestMain(m *testing.M) {
|
||||
func TestAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("Stats", func(t *testing.T) {
|
||||
for _, tailscale := range []bool{true, false} {
|
||||
t.Run(fmt.Sprintf("tailscale=%v", tailscale), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Parallel()
|
||||
|
||||
setupAgent := func(t *testing.T) (agent.Conn, <-chan *agent.Stats) {
|
||||
var derpMap *tailcfg.DERPMap
|
||||
if tailscale {
|
||||
derpMap = tailnettest.RunDERPAndSTUN(t)
|
||||
}
|
||||
conn, stats := setupAgent(t, agent.Metadata{
|
||||
DERPMap: derpMap,
|
||||
}, 0)
|
||||
assert.Empty(t, <-stats)
|
||||
return conn, stats
|
||||
}
|
||||
t.Run("SSH", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, stats := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0)
|
||||
|
||||
t.Run("SSH", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, stats := setupAgent(t)
|
||||
sshClient, err := conn.SSHClient()
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
session, err := sshClient.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
|
||||
sshClient, err := conn.SSHClient()
|
||||
require.NoError(t, err)
|
||||
session, err := sshClient.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
assert.EqualValues(t, 1, (<-stats).NumConns)
|
||||
assert.Greater(t, (<-stats).RxBytes, int64(0))
|
||||
assert.Greater(t, (<-stats).TxBytes, int64(0))
|
||||
})
|
||||
|
||||
assert.EqualValues(t, 1, (<-stats).NumConns)
|
||||
assert.Greater(t, (<-stats).RxBytes, int64(0))
|
||||
assert.Greater(t, (<-stats).TxBytes, int64(0))
|
||||
})
|
||||
t.Run("ReconnectingPTY", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ReconnectingPTY", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, stats := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0)
|
||||
|
||||
conn, stats := setupAgent(t)
|
||||
ptyConn, err := conn.ReconnectingPTY(uuid.NewString(), 128, 128, "/bin/bash")
|
||||
require.NoError(t, err)
|
||||
defer ptyConn.Close()
|
||||
|
||||
ptyConn, err := conn.ReconnectingPTY(uuid.NewString(), 128, 128, "/bin/bash")
|
||||
require.NoError(t, err)
|
||||
defer ptyConn.Close()
|
||||
|
||||
data, err := json.Marshal(agent.ReconnectingPTYRequest{
|
||||
Data: "echo test\r\n",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = ptyConn.Write(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
var s *agent.Stats
|
||||
require.Eventuallyf(t, func() bool {
|
||||
var ok bool
|
||||
s, ok = (<-stats)
|
||||
return ok && s.NumConns > 0 && s.RxBytes > 0 && s.TxBytes > 0
|
||||
}, testutil.WaitLong, testutil.IntervalFast,
|
||||
"never saw stats: %+v", s,
|
||||
)
|
||||
})
|
||||
data, err := json.Marshal(codersdk.ReconnectingPTYRequest{
|
||||
Data: "echo test\r\n",
|
||||
})
|
||||
}
|
||||
require.NoError(t, err)
|
||||
_, err = ptyConn.Write(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
var s *codersdk.AgentStats
|
||||
require.Eventuallyf(t, func() bool {
|
||||
var ok bool
|
||||
s, ok = (<-stats)
|
||||
return ok && s.NumConns > 0 && s.RxBytes > 0 && s.TxBytes > 0
|
||||
}, testutil.WaitLong, testutil.IntervalFast,
|
||||
"never saw stats: %+v", s,
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("SessionExec", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
session := setupSSHSession(t, agent.Metadata{})
|
||||
session := setupSSHSession(t, codersdk.WorkspaceAgentMetadata{})
|
||||
|
||||
command := "echo test"
|
||||
if runtime.GOOS == "windows" {
|
||||
@@ -129,7 +109,7 @@ func TestAgent(t *testing.T) {
|
||||
|
||||
t.Run("GitSSH", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
session := setupSSHSession(t, agent.Metadata{})
|
||||
session := setupSSHSession(t, codersdk.WorkspaceAgentMetadata{})
|
||||
command := "sh -c 'echo $GIT_SSH_COMMAND'"
|
||||
if runtime.GOOS == "windows" {
|
||||
command = "cmd.exe /c echo %GIT_SSH_COMMAND%"
|
||||
@@ -147,7 +127,7 @@ func TestAgent(t *testing.T) {
|
||||
// it seems like it could be either.
|
||||
t.Skip("ConPTY appears to be inconsistent on Windows.")
|
||||
}
|
||||
session := setupSSHSession(t, agent.Metadata{})
|
||||
session := setupSSHSession(t, codersdk.WorkspaceAgentMetadata{})
|
||||
command := "bash"
|
||||
if runtime.GOOS == "windows" {
|
||||
command = "cmd.exe"
|
||||
@@ -175,7 +155,7 @@ func TestAgent(t *testing.T) {
|
||||
|
||||
t.Run("SessionTTYExitCode", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
session := setupSSHSession(t, agent.Metadata{})
|
||||
session := setupSSHSession(t, codersdk.WorkspaceAgentMetadata{})
|
||||
command := "areallynotrealcommand"
|
||||
err := session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
|
||||
require.NoError(t, err)
|
||||
@@ -232,9 +212,10 @@ func TestAgent(t *testing.T) {
|
||||
|
||||
t.Run("SFTP", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, _ := setupAgent(t, agent.Metadata{}, 0)
|
||||
conn, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0)
|
||||
sshClient, err := conn.SSHClient()
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
client, err := sftp.NewClient(sshClient)
|
||||
require.NoError(t, err)
|
||||
tempFile := filepath.Join(t.TempDir(), "sftp")
|
||||
@@ -249,9 +230,10 @@ func TestAgent(t *testing.T) {
|
||||
t.Run("SCP", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn, _ := setupAgent(t, agent.Metadata{}, 0)
|
||||
conn, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0)
|
||||
sshClient, err := conn.SSHClient()
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
scpClient, err := scp.NewClientBySSH(sshClient)
|
||||
require.NoError(t, err)
|
||||
tempFile := filepath.Join(t.TempDir(), "scp")
|
||||
@@ -266,7 +248,7 @@ func TestAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
key := "EXAMPLE"
|
||||
value := "value"
|
||||
session := setupSSHSession(t, agent.Metadata{
|
||||
session := setupSSHSession(t, codersdk.WorkspaceAgentMetadata{
|
||||
EnvironmentVariables: map[string]string{
|
||||
key: value,
|
||||
},
|
||||
@@ -283,7 +265,7 @@ func TestAgent(t *testing.T) {
|
||||
t.Run("EnvironmentVariableExpansion", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
key := "EXAMPLE"
|
||||
session := setupSSHSession(t, agent.Metadata{
|
||||
session := setupSSHSession(t, codersdk.WorkspaceAgentMetadata{
|
||||
EnvironmentVariables: map[string]string{
|
||||
key: "$SOMETHINGNOTSET",
|
||||
},
|
||||
@@ -310,7 +292,7 @@ func TestAgent(t *testing.T) {
|
||||
t.Run(key, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
session := setupSSHSession(t, agent.Metadata{})
|
||||
session := setupSSHSession(t, codersdk.WorkspaceAgentMetadata{})
|
||||
command := "sh -c 'echo $" + key + "'"
|
||||
if runtime.GOOS == "windows" {
|
||||
command = "cmd.exe /c echo %" + key + "%"
|
||||
@@ -333,7 +315,7 @@ func TestAgent(t *testing.T) {
|
||||
t.Run(key, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
session := setupSSHSession(t, agent.Metadata{})
|
||||
session := setupSSHSession(t, codersdk.WorkspaceAgentMetadata{})
|
||||
command := "sh -c 'echo $" + key + "'"
|
||||
if runtime.GOOS == "windows" {
|
||||
command = "cmd.exe /c echo %" + key + "%"
|
||||
@@ -349,7 +331,7 @@ func TestAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
tempPath := filepath.Join(t.TempDir(), "content.txt")
|
||||
content := "somethingnice"
|
||||
setupAgent(t, agent.Metadata{
|
||||
setupAgent(t, codersdk.WorkspaceAgentMetadata{
|
||||
StartupScript: fmt.Sprintf("echo %s > %s", content, tempPath),
|
||||
}, 0)
|
||||
|
||||
@@ -384,9 +366,7 @@ func TestAgent(t *testing.T) {
|
||||
t.Skip("ConPTY appears to be inconsistent on Windows.")
|
||||
}
|
||||
|
||||
conn, _ := setupAgent(t, agent.Metadata{
|
||||
DERPMap: tailnettest.RunDERPAndSTUN(t),
|
||||
}, 0)
|
||||
conn, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0)
|
||||
id := uuid.NewString()
|
||||
netConn, err := conn.ReconnectingPTY(id, 100, 100, "/bin/bash")
|
||||
require.NoError(t, err)
|
||||
@@ -396,7 +376,7 @@ func TestAgent(t *testing.T) {
|
||||
// the shell is simultaneously sending a prompt.
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
data, err := json.Marshal(agent.ReconnectingPTYRequest{
|
||||
data, err := json.Marshal(codersdk.ReconnectingPTYRequest{
|
||||
Data: "echo test\r\n",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -462,19 +442,6 @@ func TestAgent(t *testing.T) {
|
||||
return l
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Unix",
|
||||
setup: func(t *testing.T) net.Listener {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Unix socket forwarding isn't supported on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
l, err := net.Listen("unix", filepath.Join(tmpDir, "test.sock"))
|
||||
require.NoError(t, err, "create UDP listener")
|
||||
return l
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
@@ -496,8 +463,11 @@ func TestAgent(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
// Dial the listener over WebRTC twice and test out of order
|
||||
conn, _ := setupAgent(t, agent.Metadata{}, 0)
|
||||
conn, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0)
|
||||
require.Eventually(t, func() bool {
|
||||
_, err := conn.Ping(context.Background())
|
||||
return err == nil
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
conn1, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String())
|
||||
require.NoError(t, err)
|
||||
defer conn1.Close()
|
||||
@@ -506,67 +476,27 @@ func TestAgent(t *testing.T) {
|
||||
defer conn2.Close()
|
||||
testDial(t, conn2)
|
||||
testDial(t, conn1)
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DialError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
// This test uses Unix listeners so we can very easily ensure that
|
||||
// no other tests decide to listen on the same random port we
|
||||
// picked.
|
||||
t.Skip("this test is unsupported on Windows")
|
||||
return
|
||||
}
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "coderd_agent_test_")
|
||||
require.NoError(t, err, "create temp dir")
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll(tmpDir)
|
||||
})
|
||||
|
||||
// Try to dial the non-existent Unix socket over WebRTC
|
||||
conn, _ := setupAgent(t, agent.Metadata{}, 0)
|
||||
netConn, err := conn.DialContext(context.Background(), "unix", filepath.Join(tmpDir, "test.sock"))
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "remote dial error")
|
||||
require.ErrorContains(t, err, "no such file")
|
||||
require.Nil(t, netConn)
|
||||
})
|
||||
|
||||
t.Run("Tailnet", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
derpMap := tailnettest.RunDERPAndSTUN(t)
|
||||
conn, _ := setupAgent(t, agent.Metadata{
|
||||
DERPMap: derpMap,
|
||||
}, 0)
|
||||
defer conn.Close()
|
||||
require.Eventually(t, func() bool {
|
||||
_, err := conn.Ping()
|
||||
return err == nil
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
})
|
||||
|
||||
t.Run("Speedtest", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if testing.Short() {
|
||||
t.Skip("The minimum duration for a speedtest is hardcoded in Tailscale to 5s!")
|
||||
}
|
||||
t.Skip("This test is relatively flakey because of Tailscale's speedtest code...")
|
||||
derpMap := tailnettest.RunDERPAndSTUN(t)
|
||||
conn, _ := setupAgent(t, agent.Metadata{
|
||||
conn, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{
|
||||
DERPMap: derpMap,
|
||||
}, 0)
|
||||
defer conn.Close()
|
||||
res, err := conn.Speedtest(speedtest.Upload, speedtest.MinDuration)
|
||||
res, err := conn.Speedtest(speedtest.Upload, 250*time.Millisecond)
|
||||
require.NoError(t, err)
|
||||
t.Logf("%.2f MBits/s", res[len(res)-1].MBitsPerSecond())
|
||||
})
|
||||
}
|
||||
|
||||
func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exec.Cmd {
|
||||
agentConn, _ := setupAgent(t, agent.Metadata{}, 0)
|
||||
agentConn, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0)
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
waitGroup := sync.WaitGroup{}
|
||||
@@ -578,7 +508,7 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe
|
||||
return
|
||||
}
|
||||
ssh, err := agentConn.SSH()
|
||||
if !assert.NoError(t, err) {
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
@@ -603,7 +533,7 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe
|
||||
return exec.Command("ssh", args...)
|
||||
}
|
||||
|
||||
func setupSSHSession(t *testing.T, options agent.Metadata) *ssh.Session {
|
||||
func setupSSHSession(t *testing.T, options codersdk.WorkspaceAgentMetadata) *ssh.Session {
|
||||
conn, _ := setupAgent(t, options, 0)
|
||||
sshClient, err := conn.SSHClient()
|
||||
require.NoError(t, err)
|
||||
@@ -621,35 +551,37 @@ func (c closeFunc) Close() error {
|
||||
return c()
|
||||
}
|
||||
|
||||
func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) (
|
||||
agent.Conn,
|
||||
<-chan *agent.Stats,
|
||||
func setupAgent(t *testing.T, metadata codersdk.WorkspaceAgentMetadata, ptyTimeout time.Duration) (
|
||||
*codersdk.AgentConn,
|
||||
<-chan *codersdk.AgentStats,
|
||||
) {
|
||||
client, server := provisionersdk.TransportPipe()
|
||||
tailscale := metadata.DERPMap != nil
|
||||
if metadata.DERPMap == nil {
|
||||
metadata.DERPMap = tailnettest.RunDERPAndSTUN(t)
|
||||
}
|
||||
coordinator := tailnet.NewCoordinator()
|
||||
agentID := uuid.New()
|
||||
statsCh := make(chan *agent.Stats)
|
||||
statsCh := make(chan *codersdk.AgentStats)
|
||||
closer := agent.New(agent.Options{
|
||||
FetchMetadata: func(ctx context.Context) (agent.Metadata, error) {
|
||||
FetchMetadata: func(ctx context.Context) (codersdk.WorkspaceAgentMetadata, error) {
|
||||
return metadata, nil
|
||||
},
|
||||
WebRTCDialer: func(ctx context.Context, logger slog.Logger) (*peerbroker.Listener, error) {
|
||||
listener, err := peerbroker.Listen(server, nil)
|
||||
return listener, err
|
||||
},
|
||||
CoordinatorDialer: func(ctx context.Context) (net.Conn, error) {
|
||||
clientConn, serverConn := net.Pipe()
|
||||
closed := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
_ = serverConn.Close()
|
||||
_ = clientConn.Close()
|
||||
<-closed
|
||||
})
|
||||
go coordinator.ServeAgent(serverConn, agentID)
|
||||
go func() {
|
||||
_ = coordinator.ServeAgent(serverConn, agentID)
|
||||
close(closed)
|
||||
}()
|
||||
return clientConn, nil
|
||||
},
|
||||
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
|
||||
ReconnectingPTYTimeout: ptyTimeout,
|
||||
StatsReporter: func(ctx context.Context, log slog.Logger, statsFn func() *agent.Stats) (io.Closer, error) {
|
||||
StatsReporter: func(ctx context.Context, log slog.Logger, statsFn func() *codersdk.AgentStats) (io.Closer, error) {
|
||||
doneCh := make(chan struct{})
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
@@ -683,46 +615,27 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration)
|
||||
},
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
_ = client.Close()
|
||||
_ = server.Close()
|
||||
_ = closer.Close()
|
||||
})
|
||||
api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client))
|
||||
stream, err := api.NegotiateConnection(context.Background())
|
||||
assert.NoError(t, err)
|
||||
if tailscale {
|
||||
conn, err := tailnet.NewConn(&tailnet.Options{
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
|
||||
DERPMap: metadata.DERPMap,
|
||||
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
clientConn, serverConn := net.Pipe()
|
||||
t.Cleanup(func() {
|
||||
_ = clientConn.Close()
|
||||
_ = serverConn.Close()
|
||||
_ = conn.Close()
|
||||
})
|
||||
go coordinator.ServeClient(serverConn, uuid.New(), agentID)
|
||||
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error {
|
||||
return conn.UpdateNodes(node)
|
||||
})
|
||||
conn.SetNodeCallback(sendNode)
|
||||
return &agent.TailnetConn{
|
||||
Conn: conn,
|
||||
}, statsCh
|
||||
}
|
||||
conn, err := peerbroker.Dial(stream, []webrtc.ICEServer{}, &peer.ConnOptions{
|
||||
Logger: slogtest.Make(t, nil),
|
||||
conn, err := tailnet.NewConn(&tailnet.Options{
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
|
||||
DERPMap: metadata.DERPMap,
|
||||
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
clientConn, serverConn := net.Pipe()
|
||||
t.Cleanup(func() {
|
||||
_ = clientConn.Close()
|
||||
_ = serverConn.Close()
|
||||
_ = conn.Close()
|
||||
})
|
||||
|
||||
return &agent.WebRTCConn{
|
||||
Negotiator: api,
|
||||
Conn: conn,
|
||||
go coordinator.ServeClient(serverConn, uuid.New(), agentID)
|
||||
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error {
|
||||
return conn.UpdateNodes(node)
|
||||
})
|
||||
conn.SetNodeCallback(sendNode)
|
||||
return &codersdk.AgentConn{
|
||||
Conn: conn,
|
||||
}, statsCh
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,184 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/retry"
|
||||
)
|
||||
|
||||
// WorkspaceAgentApps fetches the workspace apps.
|
||||
type WorkspaceAgentApps func(context.Context) ([]codersdk.WorkspaceApp, error)
|
||||
|
||||
// PostWorkspaceAgentAppHealth updates the workspace app health.
|
||||
type PostWorkspaceAgentAppHealth func(context.Context, codersdk.PostWorkspaceAppHealthsRequest) error
|
||||
|
||||
// WorkspaceAppHealthReporter is a function that checks and reports the health of the workspace apps until the passed context is canceled.
|
||||
type WorkspaceAppHealthReporter func(ctx context.Context)
|
||||
|
||||
// NewWorkspaceAppHealthReporter creates a WorkspaceAppHealthReporter that reports app health to coderd.
|
||||
func NewWorkspaceAppHealthReporter(logger slog.Logger, workspaceAgentApps WorkspaceAgentApps, postWorkspaceAgentAppHealth PostWorkspaceAgentAppHealth) WorkspaceAppHealthReporter {
|
||||
runHealthcheckLoop := func(ctx context.Context) error {
|
||||
apps, err := workspaceAgentApps(ctx)
|
||||
if err != nil {
|
||||
if xerrors.Is(err, context.Canceled) {
|
||||
return nil
|
||||
}
|
||||
return xerrors.Errorf("getting workspace apps: %w", err)
|
||||
}
|
||||
|
||||
// no need to run this loop if no apps for this workspace.
|
||||
if len(apps) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
hasHealthchecksEnabled := false
|
||||
health := make(map[string]codersdk.WorkspaceAppHealth, 0)
|
||||
for _, app := range apps {
|
||||
health[app.Name] = app.Health
|
||||
if !hasHealthchecksEnabled && app.Health != codersdk.WorkspaceAppHealthDisabled {
|
||||
hasHealthchecksEnabled = true
|
||||
}
|
||||
}
|
||||
|
||||
// no need to run this loop if no health checks are configured.
|
||||
if !hasHealthchecksEnabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
// run a ticker for each app health check.
|
||||
var mu sync.RWMutex
|
||||
failures := make(map[string]int, 0)
|
||||
for _, nextApp := range apps {
|
||||
if !shouldStartTicker(nextApp) {
|
||||
continue
|
||||
}
|
||||
app := nextApp
|
||||
t := time.NewTicker(time.Duration(app.Healthcheck.Interval) * time.Second)
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-t.C:
|
||||
}
|
||||
// we set the http timeout to the healthcheck interval to prevent getting too backed up.
|
||||
client := &http.Client{
|
||||
Timeout: time.Duration(app.Healthcheck.Interval) * time.Second,
|
||||
}
|
||||
err := func() error {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, app.Healthcheck.URL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// successful healthcheck is a non-5XX status code
|
||||
res.Body.Close()
|
||||
if res.StatusCode >= http.StatusInternalServerError {
|
||||
return xerrors.Errorf("error status code: %d", res.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}()
|
||||
if err != nil {
|
||||
mu.Lock()
|
||||
if failures[app.Name] < int(app.Healthcheck.Threshold) {
|
||||
// increment the failure count and keep status the same.
|
||||
// we will change it when we hit the threshold.
|
||||
failures[app.Name]++
|
||||
} else {
|
||||
// set to unhealthy if we hit the failure threshold.
|
||||
// we stop incrementing at the threshold to prevent the failure value from increasing forever.
|
||||
health[app.Name] = codersdk.WorkspaceAppHealthUnhealthy
|
||||
}
|
||||
mu.Unlock()
|
||||
} else {
|
||||
mu.Lock()
|
||||
// we only need one successful health check to be considered healthy.
|
||||
health[app.Name] = codersdk.WorkspaceAppHealthHealthy
|
||||
failures[app.Name] = 0
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
t.Reset(time.Duration(app.Healthcheck.Interval) * time.Second)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
lastHealth := copyHealth(health)
|
||||
mu.Unlock()
|
||||
reportTicker := time.NewTicker(time.Second)
|
||||
// every second we check if the health values of the apps have changed
|
||||
// and if there is a change we will report the new values.
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-reportTicker.C:
|
||||
mu.RLock()
|
||||
changed := healthChanged(lastHealth, health)
|
||||
mu.RUnlock()
|
||||
if !changed {
|
||||
continue
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
lastHealth = copyHealth(health)
|
||||
mu.Unlock()
|
||||
err := postWorkspaceAgentAppHealth(ctx, codersdk.PostWorkspaceAppHealthsRequest{
|
||||
Healths: lastHealth,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error(ctx, "failed to report workspace app stat", slog.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return func(ctx context.Context) {
|
||||
for r := retry.New(time.Second, 30*time.Second); r.Wait(ctx); {
|
||||
err := runHealthcheckLoop(ctx)
|
||||
if err == nil || xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
|
||||
return
|
||||
}
|
||||
logger.Error(ctx, "failed running workspace app reporter", slog.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func shouldStartTicker(app codersdk.WorkspaceApp) bool {
|
||||
return app.Healthcheck.URL != "" && app.Healthcheck.Interval > 0 && app.Healthcheck.Threshold > 0
|
||||
}
|
||||
|
||||
func healthChanged(old map[string]codersdk.WorkspaceAppHealth, new map[string]codersdk.WorkspaceAppHealth) bool {
|
||||
for name, newValue := range new {
|
||||
oldValue, found := old[name]
|
||||
if !found {
|
||||
return true
|
||||
}
|
||||
if newValue != oldValue {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func copyHealth(h1 map[string]codersdk.WorkspaceAppHealth) map[string]codersdk.WorkspaceAppHealth {
|
||||
h2 := make(map[string]codersdk.WorkspaceAppHealth, 0)
|
||||
for k, v := range h1 {
|
||||
h2[k] = v
|
||||
}
|
||||
|
||||
return h2
|
||||
}
|
||||
@@ -0,0 +1,209 @@
|
||||
package agent_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/agent"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
func TestAppHealth(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("Healthy", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
apps := []codersdk.WorkspaceApp{
|
||||
{
|
||||
Name: "app1",
|
||||
Healthcheck: codersdk.Healthcheck{},
|
||||
Health: codersdk.WorkspaceAppHealthDisabled,
|
||||
},
|
||||
{
|
||||
Name: "app2",
|
||||
Healthcheck: codersdk.Healthcheck{
|
||||
// URL: We don't set the URL for this test because the setup will
|
||||
// create a httptest server for us and set it for us.
|
||||
Interval: 1,
|
||||
Threshold: 1,
|
||||
},
|
||||
Health: codersdk.WorkspaceAppHealthInitializing,
|
||||
},
|
||||
}
|
||||
handlers := []http.Handler{
|
||||
nil,
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(r.Context(), w, http.StatusOK, nil)
|
||||
}),
|
||||
}
|
||||
getApps, closeFn := setupAppReporter(ctx, t, apps, handlers)
|
||||
defer closeFn()
|
||||
apps, err := getApps(ctx)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, codersdk.WorkspaceAppHealthDisabled, apps[0].Health)
|
||||
require.Eventually(t, func() bool {
|
||||
apps, err := getApps(ctx)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return apps[1].Health == codersdk.WorkspaceAppHealthHealthy
|
||||
}, testutil.WaitLong, testutil.IntervalSlow)
|
||||
})
|
||||
|
||||
t.Run("500", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
apps := []codersdk.WorkspaceApp{
|
||||
{
|
||||
Name: "app2",
|
||||
Healthcheck: codersdk.Healthcheck{
|
||||
// URL: We don't set the URL for this test because the setup will
|
||||
// create a httptest server for us and set it for us.
|
||||
Interval: 1,
|
||||
Threshold: 1,
|
||||
},
|
||||
Health: codersdk.WorkspaceAppHealthInitializing,
|
||||
},
|
||||
}
|
||||
handlers := []http.Handler{
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(r.Context(), w, http.StatusInternalServerError, nil)
|
||||
}),
|
||||
}
|
||||
getApps, closeFn := setupAppReporter(ctx, t, apps, handlers)
|
||||
defer closeFn()
|
||||
require.Eventually(t, func() bool {
|
||||
apps, err := getApps(ctx)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return apps[0].Health == codersdk.WorkspaceAppHealthUnhealthy
|
||||
}, testutil.WaitLong, testutil.IntervalSlow)
|
||||
})
|
||||
|
||||
t.Run("Timeout", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
apps := []codersdk.WorkspaceApp{
|
||||
{
|
||||
Name: "app2",
|
||||
Healthcheck: codersdk.Healthcheck{
|
||||
// URL: We don't set the URL for this test because the setup will
|
||||
// create a httptest server for us and set it for us.
|
||||
Interval: 1,
|
||||
Threshold: 1,
|
||||
},
|
||||
Health: codersdk.WorkspaceAppHealthInitializing,
|
||||
},
|
||||
}
|
||||
handlers := []http.Handler{
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// sleep longer than the interval to cause the health check to time out
|
||||
time.Sleep(2 * time.Second)
|
||||
httpapi.Write(r.Context(), w, http.StatusOK, nil)
|
||||
}),
|
||||
}
|
||||
getApps, closeFn := setupAppReporter(ctx, t, apps, handlers)
|
||||
defer closeFn()
|
||||
require.Eventually(t, func() bool {
|
||||
apps, err := getApps(ctx)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return apps[0].Health == codersdk.WorkspaceAppHealthUnhealthy
|
||||
}, testutil.WaitLong, testutil.IntervalSlow)
|
||||
})
|
||||
|
||||
t.Run("NotSpamming", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
apps := []codersdk.WorkspaceApp{
|
||||
{
|
||||
Name: "app2",
|
||||
Healthcheck: codersdk.Healthcheck{
|
||||
// URL: We don't set the URL for this test because the setup will
|
||||
// create a httptest server for us and set it for us.
|
||||
Interval: 1,
|
||||
Threshold: 1,
|
||||
},
|
||||
Health: codersdk.WorkspaceAppHealthInitializing,
|
||||
},
|
||||
}
|
||||
|
||||
var counter = new(int32)
|
||||
handlers := []http.Handler{
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddInt32(counter, 1)
|
||||
}),
|
||||
}
|
||||
_, closeFn := setupAppReporter(ctx, t, apps, handlers)
|
||||
defer closeFn()
|
||||
// Ensure we haven't made more than 2 (expected 1 + 1 for buffer) requests in the last second.
|
||||
// if there is a bug where we are spamming the healthcheck route this will catch it.
|
||||
time.Sleep(time.Second)
|
||||
require.LessOrEqual(t, *counter, int32(2))
|
||||
})
|
||||
}
|
||||
|
||||
func setupAppReporter(ctx context.Context, t *testing.T, apps []codersdk.WorkspaceApp, handlers []http.Handler) (agent.WorkspaceAgentApps, func()) {
|
||||
closers := []func(){}
|
||||
for i, handler := range handlers {
|
||||
if handler == nil {
|
||||
continue
|
||||
}
|
||||
ts := httptest.NewServer(handler)
|
||||
app := apps[i]
|
||||
app.Healthcheck.URL = ts.URL
|
||||
apps[i] = app
|
||||
closers = append(closers, ts.Close)
|
||||
}
|
||||
|
||||
var mu sync.Mutex
|
||||
workspaceAgentApps := func(context.Context) ([]codersdk.WorkspaceApp, error) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
var newApps []codersdk.WorkspaceApp
|
||||
return append(newApps, apps...), nil
|
||||
}
|
||||
postWorkspaceAgentAppHealth := func(_ context.Context, req codersdk.PostWorkspaceAppHealthsRequest) error {
|
||||
mu.Lock()
|
||||
for name, health := range req.Healths {
|
||||
for i, app := range apps {
|
||||
if app.Name != name {
|
||||
continue
|
||||
}
|
||||
app.Health = health
|
||||
apps[i] = app
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
go agent.NewWorkspaceAppHealthReporter(slogtest.Make(t, nil).Leveled(slog.LevelDebug), workspaceAgentApps, postWorkspaceAgentAppHealth)(ctx)
|
||||
|
||||
return workspaceAgentApps, func() {
|
||||
for _, closeFn := range closers {
|
||||
closeFn()
|
||||
}
|
||||
}
|
||||
}
|
||||
-256
@@ -1,256 +0,0 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/xerrors"
|
||||
"tailscale.com/ipn/ipnstate"
|
||||
"tailscale.com/net/speedtest"
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
"github.com/coder/coder/peer"
|
||||
"github.com/coder/coder/peerbroker/proto"
|
||||
"github.com/coder/coder/tailnet"
|
||||
)
|
||||
|
||||
// ReconnectingPTYRequest is sent from the client to the server
|
||||
// to pipe data to a PTY.
|
||||
type ReconnectingPTYRequest struct {
|
||||
Data string `json:"data"`
|
||||
Height uint16 `json:"height"`
|
||||
Width uint16 `json:"width"`
|
||||
}
|
||||
|
||||
// Conn is a temporary interface while we switch from WebRTC to Wireguard networking.
|
||||
type Conn interface {
|
||||
io.Closer
|
||||
Closed() <-chan struct{}
|
||||
Ping() (time.Duration, error)
|
||||
CloseWithError(err error) error
|
||||
ReconnectingPTY(id string, height, width uint16, command string) (net.Conn, error)
|
||||
SSH() (net.Conn, error)
|
||||
Speedtest(direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error)
|
||||
SSHClient() (*ssh.Client, error)
|
||||
DialContext(ctx context.Context, network string, addr string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// Conn wraps a peer connection with helper functions to
|
||||
// communicate with the agent.
|
||||
type WebRTCConn struct {
|
||||
// Negotiator is responsible for exchanging messages.
|
||||
Negotiator proto.DRPCPeerBrokerClient
|
||||
|
||||
*peer.Conn
|
||||
}
|
||||
|
||||
// ReconnectingPTY returns a connection serving a TTY that can
|
||||
// be reconnected to via ID.
|
||||
//
|
||||
// The command is optional and defaults to start a shell.
|
||||
func (c *WebRTCConn) ReconnectingPTY(id string, height, width uint16, command string) (net.Conn, error) {
|
||||
channel, err := c.CreateChannel(context.Background(), fmt.Sprintf("%s:%d:%d:%s", id, height, width, command), &peer.ChannelOptions{
|
||||
Protocol: ProtocolReconnectingPTY,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("pty: %w", err)
|
||||
}
|
||||
return channel.NetConn(), nil
|
||||
}
|
||||
|
||||
// SSH dials the built-in SSH server.
|
||||
func (c *WebRTCConn) SSH() (net.Conn, error) {
|
||||
channel, err := c.CreateChannel(context.Background(), "ssh", &peer.ChannelOptions{
|
||||
Protocol: ProtocolSSH,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("dial: %w", err)
|
||||
}
|
||||
return channel.NetConn(), nil
|
||||
}
|
||||
|
||||
func (*WebRTCConn) Speedtest(_ speedtest.Direction, _ time.Duration) ([]speedtest.Result, error) {
|
||||
return nil, xerrors.New("not implemented")
|
||||
}
|
||||
|
||||
// SSHClient calls SSH to create a client that uses a weak cipher
|
||||
// for high throughput.
|
||||
func (c *WebRTCConn) SSHClient() (*ssh.Client, error) {
|
||||
netConn, err := c.SSH()
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("ssh: %w", err)
|
||||
}
|
||||
sshConn, channels, requests, err := ssh.NewClientConn(netConn, "localhost:22", &ssh.ClientConfig{
|
||||
// SSH host validation isn't helpful, because obtaining a peer
|
||||
// connection already signifies user-intent to dial a workspace.
|
||||
// #nosec
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("ssh conn: %w", err)
|
||||
}
|
||||
return ssh.NewClient(sshConn, channels, requests), nil
|
||||
}
|
||||
|
||||
// DialContext dials an arbitrary protocol+address from inside the workspace and
|
||||
// proxies it through the provided net.Conn.
|
||||
func (c *WebRTCConn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
|
||||
u := &url.URL{
|
||||
Scheme: network,
|
||||
}
|
||||
if strings.HasPrefix(network, "unix") {
|
||||
u.Path = addr
|
||||
} else {
|
||||
u.Host = addr
|
||||
}
|
||||
|
||||
channel, err := c.CreateChannel(ctx, u.String(), &peer.ChannelOptions{
|
||||
Protocol: ProtocolDial,
|
||||
Unordered: strings.HasPrefix(network, "udp"),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create datachannel: %w", err)
|
||||
}
|
||||
|
||||
// The first message written from the other side is a JSON payload
|
||||
// containing the dial error.
|
||||
dec := json.NewDecoder(channel)
|
||||
var res dialResponse
|
||||
err = dec.Decode(&res)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("decode agent dial response: %w", err)
|
||||
}
|
||||
if res.Error != "" {
|
||||
_ = channel.Close()
|
||||
return nil, xerrors.Errorf("remote dial error: %v", res.Error)
|
||||
}
|
||||
|
||||
return channel.NetConn(), nil
|
||||
}
|
||||
|
||||
func (c *WebRTCConn) Close() error {
|
||||
_ = c.Negotiator.DRPCConn().Close()
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
type TailnetConn struct {
|
||||
*tailnet.Conn
|
||||
CloseFunc func()
|
||||
}
|
||||
|
||||
func (c *TailnetConn) Ping() (time.Duration, error) {
|
||||
errCh := make(chan error, 1)
|
||||
durCh := make(chan time.Duration, 1)
|
||||
c.Conn.Ping(tailnetIP, tailcfg.PingICMP, func(pr *ipnstate.PingResult) {
|
||||
if pr.Err != "" {
|
||||
errCh <- xerrors.New(pr.Err)
|
||||
return
|
||||
}
|
||||
durCh <- time.Duration(pr.LatencySeconds * float64(time.Second))
|
||||
})
|
||||
select {
|
||||
case err := <-errCh:
|
||||
return 0, err
|
||||
case dur := <-durCh:
|
||||
return dur, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *TailnetConn) CloseWithError(_ error) error {
|
||||
return c.Close()
|
||||
}
|
||||
|
||||
func (c *TailnetConn) Close() error {
|
||||
if c.CloseFunc != nil {
|
||||
c.CloseFunc()
|
||||
}
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
type reconnectingPTYInit struct {
|
||||
ID string
|
||||
Height uint16
|
||||
Width uint16
|
||||
Command string
|
||||
}
|
||||
|
||||
func (c *TailnetConn) ReconnectingPTY(id string, height, width uint16, command string) (net.Conn, error) {
|
||||
conn, err := c.DialContextTCP(context.Background(), netip.AddrPortFrom(tailnetIP, uint16(tailnetReconnectingPTYPort)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data, err := json.Marshal(reconnectingPTYInit{
|
||||
ID: id,
|
||||
Height: height,
|
||||
Width: width,
|
||||
Command: command,
|
||||
})
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
data = append(make([]byte, 2), data...)
|
||||
binary.LittleEndian.PutUint16(data, uint16(len(data)-2))
|
||||
|
||||
_, err = conn.Write(data)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (c *TailnetConn) SSH() (net.Conn, error) {
|
||||
return c.DialContextTCP(context.Background(), netip.AddrPortFrom(tailnetIP, uint16(tailnetSSHPort)))
|
||||
}
|
||||
|
||||
// SSHClient calls SSH to create a client that uses a weak cipher
|
||||
// for high throughput.
|
||||
func (c *TailnetConn) SSHClient() (*ssh.Client, error) {
|
||||
netConn, err := c.SSH()
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("ssh: %w", err)
|
||||
}
|
||||
sshConn, channels, requests, err := ssh.NewClientConn(netConn, "localhost:22", &ssh.ClientConfig{
|
||||
// SSH host validation isn't helpful, because obtaining a peer
|
||||
// connection already signifies user-intent to dial a workspace.
|
||||
// #nosec
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("ssh conn: %w", err)
|
||||
}
|
||||
return ssh.NewClient(sshConn, channels, requests), nil
|
||||
}
|
||||
|
||||
func (c *TailnetConn) Speedtest(direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) {
|
||||
speedConn, err := c.DialContextTCP(context.Background(), netip.AddrPortFrom(tailnetIP, uint16(tailnetSpeedtestPort)))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("dial speedtest: %w", err)
|
||||
}
|
||||
results, err := speedtest.RunClientWithConn(direction, duration, speedConn)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("run speedtest: %w", err)
|
||||
}
|
||||
return results, err
|
||||
}
|
||||
|
||||
func (c *TailnetConn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
|
||||
_, rawPort, _ := net.SplitHostPort(addr)
|
||||
port, _ := strconv.Atoi(rawPort)
|
||||
ipp := netip.AddrPortFrom(tailnetIP, uint16(port))
|
||||
if network == "udp" {
|
||||
return c.Conn.DialContextUDP(ctx, ipp)
|
||||
}
|
||||
return c.Conn.DialContextTCP(ctx, ipp)
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
//go:build linux || (windows && amd64)
|
||||
|
||||
package agent
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/cakturk/go-netstat/netstat"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
||||
func (lp *listeningPortsHandler) getListeningPorts() ([]codersdk.ListeningPort, error) {
|
||||
lp.mut.Lock()
|
||||
defer lp.mut.Unlock()
|
||||
|
||||
if time.Since(lp.mtime) < time.Second {
|
||||
// copy
|
||||
ports := make([]codersdk.ListeningPort, len(lp.ports))
|
||||
copy(ports, lp.ports)
|
||||
return ports, nil
|
||||
}
|
||||
|
||||
tabs, err := netstat.TCPSocks(func(s *netstat.SockTabEntry) bool {
|
||||
return s.State == netstat.Listen
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("scan listening ports: %w", err)
|
||||
}
|
||||
|
||||
seen := make(map[uint16]struct{}, len(tabs))
|
||||
ports := []codersdk.ListeningPort{}
|
||||
for _, tab := range tabs {
|
||||
if tab.LocalAddr == nil || tab.LocalAddr.Port < uint16(codersdk.MinimumListeningPort) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Don't include ports that we've already seen. This can happen on
|
||||
// Windows, and maybe on Linux if you're using a shared listener socket.
|
||||
if _, ok := seen[tab.LocalAddr.Port]; ok {
|
||||
continue
|
||||
}
|
||||
seen[tab.LocalAddr.Port] = struct{}{}
|
||||
|
||||
procName := ""
|
||||
if tab.Process != nil {
|
||||
procName = tab.Process.Name
|
||||
}
|
||||
ports = append(ports, codersdk.ListeningPort{
|
||||
ProcessName: procName,
|
||||
Network: codersdk.ListeningPortNetworkTCP,
|
||||
Port: tab.LocalAddr.Port,
|
||||
})
|
||||
}
|
||||
|
||||
lp.ports = ports
|
||||
lp.mtime = time.Now()
|
||||
|
||||
// copy
|
||||
ports = make([]codersdk.ListeningPort, len(lp.ports))
|
||||
copy(ports, lp.ports)
|
||||
return ports, nil
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
//go:build !linux && !(windows && amd64)
|
||||
|
||||
package agent
|
||||
|
||||
import "github.com/coder/coder/codersdk"
|
||||
|
||||
func (lp *listeningPortsHandler) getListeningPorts() ([]codersdk.ListeningPort, error) {
|
||||
// Can't scan for ports on non-linux or non-windows_amd64 systems at the
|
||||
// moment. The UI will not show any "no ports found" message to the user, so
|
||||
// the user won't suspect a thing.
|
||||
return []codersdk.ListeningPort{}, nil
|
||||
}
|
||||
+4
-3
@@ -7,6 +7,7 @@ import (
|
||||
"sync/atomic"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
||||
// statsConn wraps a net.Conn with statistics.
|
||||
@@ -40,8 +41,8 @@ type Stats struct {
|
||||
TxBytes int64 `json:"tx_bytes"`
|
||||
}
|
||||
|
||||
func (s *Stats) Copy() *Stats {
|
||||
return &Stats{
|
||||
func (s *Stats) Copy() *codersdk.AgentStats {
|
||||
return &codersdk.AgentStats{
|
||||
NumConns: atomic.LoadInt64(&s.NumConns),
|
||||
RxBytes: atomic.LoadInt64(&s.RxBytes),
|
||||
TxBytes: atomic.LoadInt64(&s.TxBytes),
|
||||
@@ -63,5 +64,5 @@ func (s *Stats) wrapConn(conn net.Conn) net.Conn {
|
||||
type StatsReporter func(
|
||||
ctx context.Context,
|
||||
log slog.Logger,
|
||||
stats func() *Stats,
|
||||
stats func() *codersdk.AgentStats,
|
||||
) (io.Closer, error)
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
||||
func (*agent) statisticsHandler() http.Handler {
|
||||
r := chi.NewRouter()
|
||||
r.Get("/", func(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.Response{
|
||||
Message: "Hello from the agent!",
|
||||
})
|
||||
})
|
||||
|
||||
lp := &listeningPortsHandler{}
|
||||
r.Get("/api/v0/listening-ports", lp.handler)
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
type listeningPortsHandler struct {
|
||||
mut sync.Mutex
|
||||
ports []codersdk.ListeningPort
|
||||
mtime time.Time
|
||||
}
|
||||
|
||||
// handler returns a list of listening ports. This is tested by coderd's
|
||||
// TestWorkspaceAgentListeningPorts test.
|
||||
func (lp *listeningPortsHandler) handler(rw http.ResponseWriter, r *http.Request) {
|
||||
ports, err := lp.getListeningPorts()
|
||||
if err != nil {
|
||||
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Could not scan for listening ports.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.ListeningPortsResponse{
|
||||
Ports: ports,
|
||||
})
|
||||
}
|
||||
+16
-9
@@ -32,7 +32,6 @@ func workspaceAgent() *cobra.Command {
|
||||
pprofEnabled bool
|
||||
pprofAddress string
|
||||
noReap bool
|
||||
wireguard bool
|
||||
)
|
||||
cmd := &cobra.Command{
|
||||
Use: "agent",
|
||||
@@ -169,6 +168,18 @@ func workspaceAgent() *cobra.Command {
|
||||
}
|
||||
}
|
||||
|
||||
retryCtx, cancelRetry := context.WithTimeout(cmd.Context(), time.Hour)
|
||||
defer cancelRetry()
|
||||
for retrier := retry.New(100*time.Millisecond, 5*time.Second); retrier.Wait(retryCtx); {
|
||||
err := client.PostWorkspaceAgentVersion(retryCtx, version)
|
||||
if err != nil {
|
||||
logger.Warn(retryCtx, "post agent version: %w", slog.Error(err), slog.F("version", version))
|
||||
continue
|
||||
}
|
||||
logger.Info(retryCtx, "updated agent version", slog.F("version", version))
|
||||
break
|
||||
}
|
||||
|
||||
executablePath, err := os.Executable()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("getting os executable: %w", err)
|
||||
@@ -178,21 +189,18 @@ func workspaceAgent() *cobra.Command {
|
||||
return xerrors.Errorf("add executable to $PATH: %w", err)
|
||||
}
|
||||
|
||||
if err := client.PostWorkspaceAgentVersion(cmd.Context(), version); err != nil {
|
||||
logger.Error(cmd.Context(), "post agent version: %w", slog.Error(err), slog.F("version", version))
|
||||
}
|
||||
|
||||
closer := agent.New(agent.Options{
|
||||
FetchMetadata: client.WorkspaceAgentMetadata,
|
||||
WebRTCDialer: client.ListenWorkspaceAgent,
|
||||
Logger: logger,
|
||||
EnvironmentVariables: map[string]string{
|
||||
// Override the "CODER_AGENT_TOKEN" variable in all
|
||||
// shells so "gitssh" works!
|
||||
"CODER_AGENT_TOKEN": client.SessionToken,
|
||||
},
|
||||
CoordinatorDialer: client.ListenWorkspaceAgentTailnet,
|
||||
StatsReporter: client.AgentReportStats,
|
||||
CoordinatorDialer: client.ListenWorkspaceAgentTailnet,
|
||||
StatsReporter: client.AgentReportStats,
|
||||
WorkspaceAgentApps: client.WorkspaceAgentApps,
|
||||
PostWorkspaceAgentAppHealth: client.PostWorkspaceAgentAppHealth,
|
||||
})
|
||||
<-cmd.Context().Done()
|
||||
return closer.Close()
|
||||
@@ -203,6 +211,5 @@ func workspaceAgent() *cobra.Command {
|
||||
cliflag.BoolVarP(cmd.Flags(), &pprofEnabled, "pprof-enable", "", "CODER_AGENT_PPROF_ENABLE", false, "Enable serving pprof metrics on the address defined by --pprof-address.")
|
||||
cliflag.BoolVarP(cmd.Flags(), &noReap, "no-reap", "", "", false, "Do not start a process reaper.")
|
||||
cliflag.StringVarP(cmd.Flags(), &pprofAddress, "pprof-address", "", "CODER_AGENT_PPROF_ADDRESS", "127.0.0.1:6060", "The address to serve pprof.")
|
||||
cliflag.BoolVarP(cmd.Flags(), &wireguard, "wireguard", "", "CODER_AGENT_WIREGUARD", true, "Whether to start the Wireguard interface.")
|
||||
return cmd
|
||||
}
|
||||
|
||||
+26
-16
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/provisioner/echo"
|
||||
"github.com/coder/coder/provisionersdk/proto"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
func TestWorkspaceAgent(t *testing.T) {
|
||||
@@ -47,7 +48,7 @@ func TestWorkspaceAgent(t *testing.T) {
|
||||
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
cmd, _ := clitest.New(t, "agent", "--auth", "azure-instance-identity", "--agent-url", client.URL.String(), "--wireguard=false")
|
||||
cmd, _ := clitest.New(t, "agent", "--auth", "azure-instance-identity", "--agent-url", client.URL.String())
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
errC := make(chan error)
|
||||
@@ -57,17 +58,20 @@ func TestWorkspaceAgent(t *testing.T) {
|
||||
ctx := context.WithValue(ctx, "azure-client", metadataClient)
|
||||
errC <- cmd.ExecuteContext(ctx)
|
||||
}()
|
||||
coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
|
||||
resources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID)
|
||||
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
|
||||
workspace, err := client.Workspace(ctx, workspace.ID)
|
||||
require.NoError(t, err)
|
||||
if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) {
|
||||
resources := workspace.LatestBuild.Resources
|
||||
if assert.NotEmpty(t, workspace.LatestBuild.Resources) && assert.NotEmpty(t, resources[0].Agents) {
|
||||
assert.NotEmpty(t, resources[0].Agents[0].Version)
|
||||
}
|
||||
dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil)
|
||||
require.NoError(t, err)
|
||||
defer dialer.Close()
|
||||
_, err = dialer.Ping()
|
||||
require.NoError(t, err)
|
||||
require.Eventually(t, func() bool {
|
||||
_, err := dialer.Ping(ctx)
|
||||
return err == nil
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
cancelFunc()
|
||||
err = <-errC
|
||||
require.NoError(t, err)
|
||||
@@ -105,7 +109,7 @@ func TestWorkspaceAgent(t *testing.T) {
|
||||
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
cmd, _ := clitest.New(t, "agent", "--auth", "aws-instance-identity", "--agent-url", client.URL.String(), "--wireguard=false")
|
||||
cmd, _ := clitest.New(t, "agent", "--auth", "aws-instance-identity", "--agent-url", client.URL.String())
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
errC := make(chan error)
|
||||
@@ -115,17 +119,20 @@ func TestWorkspaceAgent(t *testing.T) {
|
||||
ctx := context.WithValue(ctx, "aws-client", metadataClient)
|
||||
errC <- cmd.ExecuteContext(ctx)
|
||||
}()
|
||||
coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
|
||||
resources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID)
|
||||
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
|
||||
workspace, err := client.Workspace(ctx, workspace.ID)
|
||||
require.NoError(t, err)
|
||||
resources := workspace.LatestBuild.Resources
|
||||
if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) {
|
||||
assert.NotEmpty(t, resources[0].Agents[0].Version)
|
||||
}
|
||||
dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil)
|
||||
require.NoError(t, err)
|
||||
defer dialer.Close()
|
||||
_, err = dialer.Ping()
|
||||
require.NoError(t, err)
|
||||
require.Eventually(t, func() bool {
|
||||
_, err := dialer.Ping(ctx)
|
||||
return err == nil
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
cancelFunc()
|
||||
err = <-errC
|
||||
require.NoError(t, err)
|
||||
@@ -163,7 +170,7 @@ func TestWorkspaceAgent(t *testing.T) {
|
||||
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
cmd, _ := clitest.New(t, "agent", "--auth", "google-instance-identity", "--agent-url", client.URL.String(), "--wireguard=false")
|
||||
cmd, _ := clitest.New(t, "agent", "--auth", "google-instance-identity", "--agent-url", client.URL.String())
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
errC := make(chan error)
|
||||
@@ -173,17 +180,20 @@ func TestWorkspaceAgent(t *testing.T) {
|
||||
ctx := context.WithValue(ctx, "gcp-client", metadata)
|
||||
errC <- cmd.ExecuteContext(ctx)
|
||||
}()
|
||||
coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
|
||||
resources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID)
|
||||
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
|
||||
workspace, err := client.Workspace(ctx, workspace.ID)
|
||||
require.NoError(t, err)
|
||||
resources := workspace.LatestBuild.Resources
|
||||
if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) {
|
||||
assert.NotEmpty(t, resources[0].Agents[0].Version)
|
||||
}
|
||||
dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil)
|
||||
require.NoError(t, err)
|
||||
defer dialer.Close()
|
||||
_, err = dialer.Ping()
|
||||
require.NoError(t, err)
|
||||
require.Eventually(t, func() bool {
|
||||
_, err := dialer.Ping(ctx)
|
||||
return err == nil
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
cancelFunc()
|
||||
err = <-errC
|
||||
require.NoError(t, err)
|
||||
|
||||
+15
-1
@@ -18,6 +18,8 @@ import (
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/pflag"
|
||||
|
||||
"github.com/coder/coder/cli/cliui"
|
||||
)
|
||||
|
||||
// IsSetBool returns the value of the boolean flag if it is set.
|
||||
@@ -61,6 +63,18 @@ func StringVarP(flagset *pflag.FlagSet, p *string, name string, shorthand string
|
||||
flagset.StringVarP(p, name, shorthand, v, fmtUsage(usage, env))
|
||||
}
|
||||
|
||||
func StringArray(flagset *pflag.FlagSet, name, shorthand, env string, def []string, usage string) {
|
||||
v, ok := os.LookupEnv(env)
|
||||
if !ok || v == "" {
|
||||
if v == "" {
|
||||
def = []string{}
|
||||
} else {
|
||||
def = strings.Split(v, ",")
|
||||
}
|
||||
}
|
||||
flagset.StringArrayP(name, shorthand, def, fmtUsage(usage, env))
|
||||
}
|
||||
|
||||
func StringArrayVarP(flagset *pflag.FlagSet, ptr *[]string, name string, shorthand string, env string, def []string, usage string) {
|
||||
val, ok := os.LookupEnv(env)
|
||||
if ok {
|
||||
@@ -164,7 +178,7 @@ func fmtUsage(u string, env string) string {
|
||||
if strings.HasSuffix(u, ".") {
|
||||
dot = ""
|
||||
}
|
||||
u = fmt.Sprintf("%s%s\nConsumes $%s", u, dot, env)
|
||||
u = fmt.Sprintf("%s%s\n"+cliui.Styles.Placeholder.Render("Consumes $%s"), u, dot, env)
|
||||
}
|
||||
|
||||
return u
|
||||
|
||||
+1
-1
@@ -48,7 +48,7 @@ var Styles = struct {
|
||||
Field: defaultStyles.Code.Copy().Foreground(lipgloss.AdaptiveColor{Light: "#000000", Dark: "#FFFFFF"}),
|
||||
Keyword: defaultStyles.Keyword,
|
||||
Paragraph: defaultStyles.Paragraph,
|
||||
Placeholder: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#585858", Dark: "#005fff"}),
|
||||
Placeholder: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#585858", Dark: "#4d46b3"}),
|
||||
Prompt: defaultStyles.Prompt.Foreground(lipgloss.AdaptiveColor{Light: "#9B9B9B", Dark: "#5C5C5C"}),
|
||||
FocusedPrompt: defaultStyles.FocusedPrompt.Foreground(lipgloss.Color("#651fff")),
|
||||
Fuchsia: defaultStyles.SelectedMenuItem.Copy(),
|
||||
|
||||
@@ -22,7 +22,7 @@ func WorkspaceBuild(ctx context.Context, writer io.Writer, client *codersdk.Clie
|
||||
build, err := client.WorkspaceBuild(ctx, build)
|
||||
return build.Job, err
|
||||
},
|
||||
Logs: func() (<-chan codersdk.ProvisionerJobLog, error) {
|
||||
Logs: func() (<-chan codersdk.ProvisionerJobLog, io.Closer, error) {
|
||||
return client.WorkspaceBuildLogsAfter(ctx, build, before)
|
||||
},
|
||||
})
|
||||
@@ -31,7 +31,7 @@ func WorkspaceBuild(ctx context.Context, writer io.Writer, client *codersdk.Clie
|
||||
type ProvisionerJobOptions struct {
|
||||
Fetch func() (codersdk.ProvisionerJob, error)
|
||||
Cancel func() error
|
||||
Logs func() (<-chan codersdk.ProvisionerJobLog, error)
|
||||
Logs func() (<-chan codersdk.ProvisionerJobLog, io.Closer, error)
|
||||
|
||||
FetchInterval time.Duration
|
||||
// Verbose determines whether debug and trace logs will be shown.
|
||||
@@ -132,10 +132,11 @@ func ProvisionerJob(ctx context.Context, writer io.Writer, opts ProvisionerJobOp
|
||||
// The initial stage needs to print after the signal handler has been registered.
|
||||
printStage()
|
||||
|
||||
logs, err := opts.Logs()
|
||||
logs, closer, err := opts.Logs()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("logs: %w", err)
|
||||
}
|
||||
defer closer.Close()
|
||||
|
||||
var (
|
||||
// logOutput is where log output is written
|
||||
|
||||
@@ -2,6 +2,7 @@ package cliui_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
@@ -136,8 +137,10 @@ func newProvisionerJob(t *testing.T) provisionerJobTest {
|
||||
Cancel: func() error {
|
||||
return nil
|
||||
},
|
||||
Logs: func() (<-chan codersdk.ProvisionerJobLog, error) {
|
||||
return logs, nil
|
||||
Logs: func() (<-chan codersdk.ProvisionerJobLog, io.Closer, error) {
|
||||
return logs, closeFunc(func() error {
|
||||
return nil
|
||||
}), nil
|
||||
},
|
||||
})
|
||||
},
|
||||
@@ -164,3 +167,9 @@ func newProvisionerJob(t *testing.T) provisionerJobTest {
|
||||
PTY: ptty,
|
||||
}
|
||||
}
|
||||
|
||||
type closeFunc func() error
|
||||
|
||||
func (c closeFunc) Close() error {
|
||||
return c()
|
||||
}
|
||||
|
||||
+21
-2
@@ -153,10 +153,14 @@ func DisplayTable(out any, sort string, filterColumns []string) (string, error)
|
||||
// Special type formatting.
|
||||
switch val := v.(type) {
|
||||
case time.Time:
|
||||
v = val.Format(time.Stamp)
|
||||
v = val.Format(time.RFC3339)
|
||||
case *time.Time:
|
||||
if val != nil {
|
||||
v = val.Format(time.Stamp)
|
||||
v = val.Format(time.RFC3339)
|
||||
}
|
||||
case *int64:
|
||||
if val != nil {
|
||||
v = *val
|
||||
}
|
||||
case fmt.Stringer:
|
||||
if val != nil {
|
||||
@@ -305,3 +309,18 @@ func valueToTableMap(val reflect.Value) (map[string]any, error) {
|
||||
|
||||
return row, nil
|
||||
}
|
||||
|
||||
// TableHeaders returns the table header names of all
|
||||
// fields in tSlice. tSlice must be a slice of some type.
|
||||
func TableHeaders(tSlice any) ([]string, error) {
|
||||
v := reflect.Indirect(reflect.ValueOf(tSlice))
|
||||
rawHeaders, err := typeToTableHeaders(v.Type().Elem())
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("type to table headers: %w", err)
|
||||
}
|
||||
out := make([]string, 0, len(rawHeaders))
|
||||
for _, hdr := range rawHeaders {
|
||||
out = append(out, strings.Replace(hdr, " ", "_", -1))
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
+33
-11
@@ -131,10 +131,10 @@ func Test_DisplayTable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
expected := `
|
||||
NAME AGE ROLES SUB 1 NAME SUB 1 AGE SUB 2 NAME SUB 2 AGE SUB 3 INNER NAME SUB 3 INNER AGE SUB 4 TIME TIME PTR
|
||||
foo 10 [a b c] foo1 11 foo2 12 foo3 13 {foo4 14 } Aug 2 15:49:10 Aug 2 15:49:10
|
||||
bar 20 [a] bar1 21 <nil> <nil> bar3 23 {bar4 24 } Aug 2 15:49:10 <nil>
|
||||
baz 30 [] baz1 31 <nil> <nil> baz3 33 {baz4 34 } Aug 2 15:49:10 <nil>
|
||||
NAME AGE ROLES SUB 1 NAME SUB 1 AGE SUB 2 NAME SUB 2 AGE SUB 3 INNER NAME SUB 3 INNER AGE SUB 4 TIME TIME PTR
|
||||
foo 10 [a b c] foo1 11 foo2 12 foo3 13 {foo4 14 } 2022-08-02T15:49:10Z 2022-08-02T15:49:10Z
|
||||
bar 20 [a] bar1 21 <nil> <nil> bar3 23 {bar4 24 } 2022-08-02T15:49:10Z <nil>
|
||||
baz 30 [] baz1 31 <nil> <nil> baz3 33 {baz4 34 } 2022-08-02T15:49:10Z <nil>
|
||||
`
|
||||
|
||||
// Test with non-pointer values.
|
||||
@@ -158,10 +158,10 @@ baz 30 [] baz1 31 <nil> <nil> baz3
|
||||
t.Parallel()
|
||||
|
||||
expected := `
|
||||
NAME AGE ROLES SUB 1 NAME SUB 1 AGE SUB 2 NAME SUB 2 AGE SUB 3 INNER NAME SUB 3 INNER AGE SUB 4 TIME TIME PTR
|
||||
bar 20 [a] bar1 21 <nil> <nil> bar3 23 {bar4 24 } Aug 2 15:49:10 <nil>
|
||||
baz 30 [] baz1 31 <nil> <nil> baz3 33 {baz4 34 } Aug 2 15:49:10 <nil>
|
||||
foo 10 [a b c] foo1 11 foo2 12 foo3 13 {foo4 14 } Aug 2 15:49:10 Aug 2 15:49:10
|
||||
NAME AGE ROLES SUB 1 NAME SUB 1 AGE SUB 2 NAME SUB 2 AGE SUB 3 INNER NAME SUB 3 INNER AGE SUB 4 TIME TIME PTR
|
||||
bar 20 [a] bar1 21 <nil> <nil> bar3 23 {bar4 24 } 2022-08-02T15:49:10Z <nil>
|
||||
baz 30 [] baz1 31 <nil> <nil> baz3 33 {baz4 34 } 2022-08-02T15:49:10Z <nil>
|
||||
foo 10 [a b c] foo1 11 foo2 12 foo3 13 {foo4 14 } 2022-08-02T15:49:10Z 2022-08-02T15:49:10Z
|
||||
`
|
||||
|
||||
out, err := cliui.DisplayTable(in, "name", nil)
|
||||
@@ -175,9 +175,9 @@ foo 10 [a b c] foo1 11 foo2 12 foo3
|
||||
|
||||
expected := `
|
||||
NAME SUB 1 NAME SUB 3 INNER NAME TIME
|
||||
foo foo1 foo3 Aug 2 15:49:10
|
||||
bar bar1 bar3 Aug 2 15:49:10
|
||||
baz baz1 baz3 Aug 2 15:49:10
|
||||
foo foo1 foo3 2022-08-02T15:49:10Z
|
||||
bar bar1 bar3 2022-08-02T15:49:10Z
|
||||
baz baz1 baz3 2022-08-02T15:49:10Z
|
||||
`
|
||||
|
||||
out, err := cliui.DisplayTable(in, "", []string{"name", "sub_1_name", "sub_3 inner name", "time"})
|
||||
@@ -327,6 +327,28 @@ baz baz1 baz3 Aug 2 15:49:10
|
||||
})
|
||||
}
|
||||
|
||||
func Test_TableHeaders(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := []tableTest1{}
|
||||
expectedFields := []string{
|
||||
"name",
|
||||
"age",
|
||||
"roles",
|
||||
"sub_1_name",
|
||||
"sub_1_age",
|
||||
"sub_2_name",
|
||||
"sub_2_age",
|
||||
"sub_3_inner_name",
|
||||
"sub_3_inner_age",
|
||||
"sub_4",
|
||||
"time",
|
||||
"time_ptr",
|
||||
}
|
||||
headers, err := cliui.TableHeaders(s)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, expectedFields, headers)
|
||||
}
|
||||
|
||||
// compareTables normalizes the incoming table lines
|
||||
func compareTables(t *testing.T, expected, out string) {
|
||||
t.Helper()
|
||||
|
||||
@@ -13,6 +13,11 @@ func (r Root) Session() File {
|
||||
return File(filepath.Join(string(r), "session"))
|
||||
}
|
||||
|
||||
// ReplicaID is a unique identifier for the Coder server.
|
||||
func (r Root) ReplicaID() File {
|
||||
return File(filepath.Join(string(r), "replica_id"))
|
||||
}
|
||||
|
||||
func (r Root) URL() File {
|
||||
return File(filepath.Join(string(r), "url"))
|
||||
}
|
||||
|
||||
+3
-11
@@ -139,12 +139,11 @@ func configSSH() *cobra.Command {
|
||||
usePreviousOpts bool
|
||||
dryRun bool
|
||||
skipProxyCommand bool
|
||||
wireguard bool
|
||||
)
|
||||
cmd := &cobra.Command{
|
||||
Annotations: workspaceCommand,
|
||||
Use: "config-ssh",
|
||||
Short: "Populate your SSH config with Host entries for all of your workspaces",
|
||||
Short: "Add an SSH Host entry for your workspaces \"ssh coder.workspace\"",
|
||||
Example: formatExamples(
|
||||
example{
|
||||
Description: "You can use -o (or --ssh-option) so set SSH options to be used for all your workspaces",
|
||||
@@ -289,15 +288,11 @@ func configSSH() *cobra.Command {
|
||||
"\tLogLevel ERROR",
|
||||
)
|
||||
if !skipProxyCommand {
|
||||
wgArg := ""
|
||||
if wireguard {
|
||||
wgArg = "--wireguard "
|
||||
}
|
||||
configOptions = append(
|
||||
configOptions,
|
||||
fmt.Sprintf(
|
||||
"\tProxyCommand %s --global-config %s ssh %s--stdio %s",
|
||||
escapedCoderBinary, escapedGlobalConfig, wgArg, hostname,
|
||||
"\tProxyCommand %s --global-config %s ssh --stdio %s",
|
||||
escapedCoderBinary, escapedGlobalConfig, hostname,
|
||||
),
|
||||
)
|
||||
}
|
||||
@@ -374,9 +369,6 @@ func configSSH() *cobra.Command {
|
||||
cmd.Flags().BoolVarP(&skipProxyCommand, "skip-proxy-command", "", false, "Specifies whether the ProxyCommand option should be skipped. Useful for testing.")
|
||||
_ = cmd.Flags().MarkHidden("skip-proxy-command")
|
||||
cliflag.BoolVarP(cmd.Flags(), &usePreviousOpts, "use-previous-options", "", "CODER_SSH_USE_PREVIOUS_OPTIONS", false, "Specifies whether or not to keep options from previous run of config-ssh.")
|
||||
cliflag.BoolVarP(cmd.Flags(), &wireguard, "wireguard", "", "CODER_CONFIG_SSH_WIREGUARD", false, "Whether to use Wireguard for SSH tunneling.")
|
||||
_ = cmd.Flags().MarkHidden("wireguard")
|
||||
|
||||
cliui.AllowSkipPrompt(cmd)
|
||||
|
||||
return cmd
|
||||
|
||||
+20
-6
@@ -12,6 +12,7 @@ import (
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -106,14 +107,13 @@ func TestConfigSSH(t *testing.T) {
|
||||
agentClient.SessionToken = authToken
|
||||
agentCloser := agent.New(agent.Options{
|
||||
FetchMetadata: agentClient.WorkspaceAgentMetadata,
|
||||
WebRTCDialer: agentClient.ListenWorkspaceAgent,
|
||||
CoordinatorDialer: client.ListenWorkspaceAgentTailnet,
|
||||
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
|
||||
Logger: slogtest.Make(t, nil).Named("agent"),
|
||||
})
|
||||
defer func() {
|
||||
_ = agentCloser.Close()
|
||||
}()
|
||||
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
|
||||
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
|
||||
agentConn, err := client.DialWorkspaceAgent(context.Background(), resources[0].Agents[0].ID, nil)
|
||||
require.NoError(t, err)
|
||||
defer agentConn.Close()
|
||||
@@ -123,17 +123,28 @@ func TestConfigSSH(t *testing.T) {
|
||||
defer func() {
|
||||
_ = listener.Close()
|
||||
}()
|
||||
copyDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(copyDone)
|
||||
var wg sync.WaitGroup
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
break
|
||||
}
|
||||
ssh, err := agentConn.SSH()
|
||||
assert.NoError(t, err)
|
||||
go io.Copy(conn, ssh)
|
||||
go io.Copy(ssh, conn)
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = io.Copy(conn, ssh)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = io.Copy(ssh, conn)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}()
|
||||
|
||||
sshConfigFile := sshConfigFileName(t)
|
||||
@@ -178,6 +189,9 @@ func TestConfigSSH(t *testing.T) {
|
||||
data, err := sshCmd.Output()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "test", strings.TrimSpace(string(data)))
|
||||
|
||||
_ = listener.Close()
|
||||
<-copyDone
|
||||
}
|
||||
|
||||
func TestConfigSSH_FileWriteAndOptionsFlow(t *testing.T) {
|
||||
|
||||
+12
-11
@@ -2,6 +2,7 @@ package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
@@ -25,7 +26,7 @@ func create() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Annotations: workspaceCommand,
|
||||
Use: "create [name]",
|
||||
Short: "Create a workspace from a template",
|
||||
Short: "Create a workspace",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
client, err := CreateClient(cmd)
|
||||
if err != nil {
|
||||
@@ -72,7 +73,7 @@ func create() *cobra.Command {
|
||||
}
|
||||
|
||||
slices.SortFunc(templates, func(a, b codersdk.Template) bool {
|
||||
return a.WorkspaceOwnerCount > b.WorkspaceOwnerCount
|
||||
return a.ActiveUserCount > b.ActiveUserCount
|
||||
})
|
||||
|
||||
templateNames := make([]string, 0, len(templates))
|
||||
@@ -81,13 +82,13 @@ func create() *cobra.Command {
|
||||
for _, template := range templates {
|
||||
templateName := template.Name
|
||||
|
||||
if template.WorkspaceOwnerCount > 0 {
|
||||
developerText := "developer"
|
||||
if template.WorkspaceOwnerCount != 1 {
|
||||
developerText = "developers"
|
||||
}
|
||||
|
||||
templateName += cliui.Styles.Placeholder.Render(fmt.Sprintf(" (used by %d %s)", template.WorkspaceOwnerCount, developerText))
|
||||
if template.ActiveUserCount > 0 {
|
||||
templateName += cliui.Styles.Placeholder.Render(
|
||||
fmt.Sprintf(
|
||||
" (used by %s)",
|
||||
formatActiveDevelopers(template.ActiveUserCount),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
templateNames = append(templateNames, templateName)
|
||||
@@ -139,7 +140,7 @@ func create() *cobra.Command {
|
||||
}
|
||||
|
||||
after := time.Now()
|
||||
workspace, err := client.CreateWorkspace(cmd.Context(), organization.ID, codersdk.CreateWorkspaceRequest{
|
||||
workspace, err := client.CreateWorkspace(cmd.Context(), organization.ID, codersdk.Me, codersdk.CreateWorkspaceRequest{
|
||||
TemplateID: template.ID,
|
||||
Name: workspaceName,
|
||||
AutostartSchedule: schedSpec,
|
||||
@@ -253,7 +254,7 @@ PromptParamLoop:
|
||||
Cancel: func() error {
|
||||
return client.CancelTemplateVersionDryRun(cmd.Context(), templateVersion.ID, dryRun.ID)
|
||||
},
|
||||
Logs: func() (<-chan codersdk.ProvisionerJobLog, error) {
|
||||
Logs: func() (<-chan codersdk.ProvisionerJobLog, io.Closer, error) {
|
||||
return client.TemplateVersionDryRunLogsAfter(cmd.Context(), templateVersion.ID, dryRun.ID, after)
|
||||
},
|
||||
// Don't show log output for the dry-run unless there's an error.
|
||||
|
||||
@@ -0,0 +1,511 @@
|
||||
package deployment
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/spf13/pflag"
|
||||
|
||||
"github.com/coder/coder/cli/cliflag"
|
||||
"github.com/coder/coder/cli/cliui"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
||||
const (
|
||||
secretValue = "********"
|
||||
)
|
||||
|
||||
func Flags() *codersdk.DeploymentFlags {
|
||||
return &codersdk.DeploymentFlags{
|
||||
AccessURL: &codersdk.StringFlag{
|
||||
Name: "Access URL",
|
||||
Flag: "access-url",
|
||||
EnvVar: "CODER_ACCESS_URL",
|
||||
Description: "External URL to access your deployment. This must be accessible by all provisioned workspaces.",
|
||||
},
|
||||
WildcardAccessURL: &codersdk.StringFlag{
|
||||
Name: "Wildcard Address URL",
|
||||
Flag: "wildcard-access-url",
|
||||
EnvVar: "CODER_WILDCARD_ACCESS_URL",
|
||||
Description: `Specifies the wildcard hostname to use for workspace applications in the form "*.example.com" or "*-suffix.example.com". Ports or schemes should not be included. The scheme will be copied from the access URL.`,
|
||||
},
|
||||
Address: &codersdk.StringFlag{
|
||||
Name: "Bind Address",
|
||||
Flag: "address",
|
||||
EnvVar: "CODER_ADDRESS",
|
||||
Shorthand: "a",
|
||||
Description: "Bind address of the server.",
|
||||
Default: "127.0.0.1:3000",
|
||||
},
|
||||
AutobuildPollInterval: &codersdk.DurationFlag{
|
||||
Name: "Autobuild Poll Interval",
|
||||
Flag: "autobuild-poll-interval",
|
||||
EnvVar: "CODER_AUTOBUILD_POLL_INTERVAL",
|
||||
Description: "Interval to poll for scheduled workspace builds.",
|
||||
Hidden: true,
|
||||
Default: time.Minute,
|
||||
},
|
||||
DerpServerEnable: &codersdk.BoolFlag{
|
||||
Name: "DERP Server Enabled",
|
||||
Flag: "derp-server-enable",
|
||||
EnvVar: "CODER_DERP_SERVER_ENABLE",
|
||||
Description: "Whether to enable or disable the embedded DERP relay server.",
|
||||
Default: true,
|
||||
},
|
||||
DerpServerRegionID: &codersdk.IntFlag{
|
||||
Name: "DERP Server Region ID",
|
||||
Flag: "derp-server-region-id",
|
||||
EnvVar: "CODER_DERP_SERVER_REGION_ID",
|
||||
Description: "Region ID to use for the embedded DERP server.",
|
||||
Default: 999,
|
||||
},
|
||||
DerpServerRegionCode: &codersdk.StringFlag{
|
||||
Name: "DERP Server Region Code",
|
||||
Flag: "derp-server-region-code",
|
||||
EnvVar: "CODER_DERP_SERVER_REGION_CODE",
|
||||
Description: "Region code to use for the embedded DERP server.",
|
||||
Default: "coder",
|
||||
},
|
||||
DerpServerRegionName: &codersdk.StringFlag{
|
||||
Name: "DERP Server Region Name",
|
||||
Flag: "derp-server-region-name",
|
||||
EnvVar: "CODER_DERP_SERVER_REGION_NAME",
|
||||
Description: "Region name that for the embedded DERP server.",
|
||||
Default: "Coder Embedded Relay",
|
||||
},
|
||||
DerpServerSTUNAddresses: &codersdk.StringArrayFlag{
|
||||
Name: "DERP Server STUN Addresses",
|
||||
Flag: "derp-server-stun-addresses",
|
||||
EnvVar: "CODER_DERP_SERVER_STUN_ADDRESSES",
|
||||
Description: "Addresses for STUN servers to establish P2P connections. Set empty to disable P2P connections.",
|
||||
Default: []string{"stun.l.google.com:19302"},
|
||||
},
|
||||
DerpServerRelayAddress: &codersdk.StringFlag{
|
||||
Name: "DERP Server Relay Address",
|
||||
Flag: "derp-server-relay-address",
|
||||
EnvVar: "CODER_DERP_SERVER_RELAY_URL",
|
||||
Description: "An HTTP address that is accessible by other replicas to relay DERP traffic. Required for high availability.",
|
||||
Enterprise: true,
|
||||
},
|
||||
DerpConfigURL: &codersdk.StringFlag{
|
||||
Name: "DERP Config URL",
|
||||
Flag: "derp-config-url",
|
||||
EnvVar: "CODER_DERP_CONFIG_URL",
|
||||
Description: "URL to fetch a DERP mapping on startup. See: https://tailscale.com/kb/1118/custom-derp-servers/",
|
||||
},
|
||||
DerpConfigPath: &codersdk.StringFlag{
|
||||
Name: "DERP Config Path",
|
||||
Flag: "derp-config-path",
|
||||
EnvVar: "CODER_DERP_CONFIG_PATH",
|
||||
Description: "Path to read a DERP mapping from. See: https://tailscale.com/kb/1118/custom-derp-servers/",
|
||||
},
|
||||
PromEnabled: &codersdk.BoolFlag{
|
||||
Name: "Prometheus Enabled",
|
||||
Flag: "prometheus-enable",
|
||||
EnvVar: "CODER_PROMETHEUS_ENABLE",
|
||||
Description: "Serve prometheus metrics on the address defined by `prometheus-address`.",
|
||||
},
|
||||
PromAddress: &codersdk.StringFlag{
|
||||
Name: "Prometheus Address",
|
||||
Flag: "prometheus-address",
|
||||
EnvVar: "CODER_PROMETHEUS_ADDRESS",
|
||||
Description: "The bind address to serve prometheus metrics.",
|
||||
Default: "127.0.0.1:2112",
|
||||
},
|
||||
PprofEnabled: &codersdk.BoolFlag{
|
||||
Name: "pprof Enabled",
|
||||
Flag: "pprof-enable",
|
||||
EnvVar: "CODER_PPROF_ENABLE",
|
||||
Description: "Serve pprof metrics on the address defined by `pprof-address`.",
|
||||
},
|
||||
PprofAddress: &codersdk.StringFlag{
|
||||
Name: "pprof Address",
|
||||
Flag: "pprof-address",
|
||||
EnvVar: "CODER_PPROF_ADDRESS",
|
||||
Description: "The bind address to serve pprof.",
|
||||
Default: "127.0.0.1:6060",
|
||||
},
|
||||
CacheDir: &codersdk.StringFlag{
|
||||
Name: "Cache Directory",
|
||||
Flag: "cache-dir",
|
||||
EnvVar: "CODER_CACHE_DIRECTORY",
|
||||
Description: "The directory to cache temporary files. If unspecified and $CACHE_DIRECTORY is set, it will be used for compatibility with systemd.",
|
||||
Default: defaultCacheDir(),
|
||||
},
|
||||
InMemoryDatabase: &codersdk.BoolFlag{
|
||||
Name: "In-Memory Database",
|
||||
Flag: "in-memory",
|
||||
EnvVar: "CODER_INMEMORY",
|
||||
Description: "Controls whether data will be stored in an in-memory database.",
|
||||
Hidden: true,
|
||||
},
|
||||
ProvisionerDaemonCount: &codersdk.IntFlag{
|
||||
Name: "Provisioner Daemons",
|
||||
Flag: "provisioner-daemons",
|
||||
EnvVar: "CODER_PROVISIONER_DAEMONS",
|
||||
Description: "Number of provisioner daemons to create on start. If builds are stuck in queued state for a long time, consider increasing this.",
|
||||
Default: 3,
|
||||
},
|
||||
PostgresURL: &codersdk.StringFlag{
|
||||
Name: "Postgres URL",
|
||||
Flag: "postgres-url",
|
||||
EnvVar: "CODER_PG_CONNECTION_URL",
|
||||
Description: "URL of a PostgreSQL database. If empty, PostgreSQL binaries will be downloaded from Maven (https://repo1.maven.org/maven2) and store all data in the config root. Access the built-in database with \"coder server postgres-builtin-url\"",
|
||||
Secret: true,
|
||||
},
|
||||
OAuth2GithubClientID: &codersdk.StringFlag{
|
||||
Name: "Oauth2 Github Client ID",
|
||||
Flag: "oauth2-github-client-id",
|
||||
EnvVar: "CODER_OAUTH2_GITHUB_CLIENT_ID",
|
||||
Description: "Client ID for Login with GitHub.",
|
||||
},
|
||||
OAuth2GithubClientSecret: &codersdk.StringFlag{
|
||||
Name: "Oauth2 Github Client Secret",
|
||||
Flag: "oauth2-github-client-secret",
|
||||
EnvVar: "CODER_OAUTH2_GITHUB_CLIENT_SECRET",
|
||||
Description: "Client secret for Login with GitHub.",
|
||||
Secret: true,
|
||||
},
|
||||
OAuth2GithubAllowedOrganizations: &codersdk.StringArrayFlag{
|
||||
Name: "Oauth2 Github Allowed Organizations",
|
||||
Flag: "oauth2-github-allowed-orgs",
|
||||
EnvVar: "CODER_OAUTH2_GITHUB_ALLOWED_ORGS",
|
||||
Description: "Organizations the user must be a member of to Login with GitHub.",
|
||||
Default: []string{},
|
||||
},
|
||||
OAuth2GithubAllowedTeams: &codersdk.StringArrayFlag{
|
||||
Name: "Oauth2 Github Allowed Teams",
|
||||
Flag: "oauth2-github-allowed-teams",
|
||||
EnvVar: "CODER_OAUTH2_GITHUB_ALLOWED_TEAMS",
|
||||
Description: "Teams inside organizations the user must be a member of to Login with GitHub. Structured as: <organization-name>/<team-slug>.",
|
||||
Default: []string{},
|
||||
},
|
||||
OAuth2GithubAllowSignups: &codersdk.BoolFlag{
|
||||
Name: "Oauth2 Github Allow Signups",
|
||||
Flag: "oauth2-github-allow-signups",
|
||||
EnvVar: "CODER_OAUTH2_GITHUB_ALLOW_SIGNUPS",
|
||||
Description: "Whether new users can sign up with GitHub.",
|
||||
},
|
||||
OAuth2GithubEnterpriseBaseURL: &codersdk.StringFlag{
|
||||
Name: "Oauth2 Github Enterprise Base URL",
|
||||
Flag: "oauth2-github-enterprise-base-url",
|
||||
EnvVar: "CODER_OAUTH2_GITHUB_ENTERPRISE_BASE_URL",
|
||||
Description: "Base URL of a GitHub Enterprise deployment to use for Login with GitHub.",
|
||||
},
|
||||
OIDCAllowSignups: &codersdk.BoolFlag{
|
||||
Name: "OIDC Allow Signups",
|
||||
Flag: "oidc-allow-signups",
|
||||
EnvVar: "CODER_OIDC_ALLOW_SIGNUPS",
|
||||
Description: "Whether new users can sign up with OIDC.",
|
||||
Default: true,
|
||||
},
|
||||
OIDCClientID: &codersdk.StringFlag{
|
||||
Name: "OIDC Client ID",
|
||||
Flag: "oidc-client-id",
|
||||
EnvVar: "CODER_OIDC_CLIENT_ID",
|
||||
Description: "Client ID to use for Login with OIDC.",
|
||||
},
|
||||
OIDCClientSecret: &codersdk.StringFlag{
|
||||
Name: "OIDC Client Secret",
|
||||
Flag: "oidc-client-secret",
|
||||
EnvVar: "CODER_OIDC_CLIENT_SECRET",
|
||||
Description: "Client secret to use for Login with OIDC.",
|
||||
Secret: true,
|
||||
},
|
||||
OIDCEmailDomain: &codersdk.StringFlag{
|
||||
Name: "OIDC Email Domain",
|
||||
Flag: "oidc-email-domain",
|
||||
EnvVar: "CODER_OIDC_EMAIL_DOMAIN",
|
||||
Description: "Email domain that clients logging in with OIDC must match.",
|
||||
},
|
||||
OIDCIssuerURL: &codersdk.StringFlag{
|
||||
Name: "OIDC Issuer URL",
|
||||
Flag: "oidc-issuer-url",
|
||||
EnvVar: "CODER_OIDC_ISSUER_URL",
|
||||
Description: "Issuer URL to use for Login with OIDC.",
|
||||
},
|
||||
OIDCScopes: &codersdk.StringArrayFlag{
|
||||
Name: "OIDC Scopes",
|
||||
Flag: "oidc-scopes",
|
||||
EnvVar: "CODER_OIDC_SCOPES",
|
||||
Description: "Scopes to grant when authenticating with OIDC.",
|
||||
Default: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
},
|
||||
TelemetryEnable: &codersdk.BoolFlag{
|
||||
Name: "Telemetry Enabled",
|
||||
Flag: "telemetry",
|
||||
EnvVar: "CODER_TELEMETRY",
|
||||
Description: "Whether telemetry is enabled or not. Coder collects anonymized usage data to help improve our product.",
|
||||
Default: flag.Lookup("test.v") == nil,
|
||||
},
|
||||
TelemetryTraceEnable: &codersdk.BoolFlag{
|
||||
Name: "Trace Telemetry Enabled",
|
||||
Flag: "telemetry-trace",
|
||||
EnvVar: "CODER_TELEMETRY_TRACE",
|
||||
Shorthand: "",
|
||||
Description: "Whether Opentelemetry traces are sent to Coder. Coder collects anonymized application tracing to help improve our product. Disabling telemetry also disables this option.",
|
||||
Default: flag.Lookup("test.v") == nil,
|
||||
},
|
||||
TelemetryURL: &codersdk.StringFlag{
|
||||
Name: "Telemetry URL",
|
||||
Flag: "telemetry-url",
|
||||
EnvVar: "CODER_TELEMETRY_URL",
|
||||
Description: "URL to send telemetry.",
|
||||
Hidden: true,
|
||||
Default: "https://telemetry.coder.com",
|
||||
},
|
||||
TLSEnable: &codersdk.BoolFlag{
|
||||
Name: "TLS Enabled",
|
||||
Flag: "tls-enable",
|
||||
EnvVar: "CODER_TLS_ENABLE",
|
||||
Description: "Whether TLS will be enabled.",
|
||||
},
|
||||
TLSCertFiles: &codersdk.StringArrayFlag{
|
||||
Name: "TLS Cert Files",
|
||||
Flag: "tls-cert-file",
|
||||
EnvVar: "CODER_TLS_CERT_FILE",
|
||||
Description: "Path to each certificate for TLS. It requires a PEM-encoded file. " +
|
||||
"To configure the listener to use a CA certificate, concatenate the primary certificate " +
|
||||
"and the CA certificate together. The primary certificate should appear first in the combined file.",
|
||||
Default: []string{},
|
||||
},
|
||||
TLSClientCAFile: &codersdk.StringFlag{
|
||||
Name: "TLS Client CA File",
|
||||
Flag: "tls-client-ca-file",
|
||||
EnvVar: "CODER_TLS_CLIENT_CA_FILE",
|
||||
Description: "PEM-encoded Certificate Authority file used for checking the authenticity of client",
|
||||
},
|
||||
TLSClientAuth: &codersdk.StringFlag{
|
||||
Name: "TLS Client Auth",
|
||||
Flag: "tls-client-auth",
|
||||
EnvVar: "CODER_TLS_CLIENT_AUTH",
|
||||
Description: `Policy the server will follow for TLS Client Authentication. ` +
|
||||
`Accepted values are "none", "request", "require-any", "verify-if-given", or "require-and-verify"`,
|
||||
Default: "request",
|
||||
},
|
||||
TLSKeyFiles: &codersdk.StringArrayFlag{
|
||||
Name: "TLS Key Files",
|
||||
Flag: "tls-key-file",
|
||||
EnvVar: "CODER_TLS_KEY_FILE",
|
||||
Description: "Paths to the private keys for each of the certificates. It requires a PEM-encoded file",
|
||||
Default: []string{},
|
||||
},
|
||||
TLSMinVersion: &codersdk.StringFlag{
|
||||
Name: "TLS Min Version",
|
||||
Flag: "tls-min-version",
|
||||
EnvVar: "CODER_TLS_MIN_VERSION",
|
||||
Description: `Minimum supported version of TLS. Accepted values are "tls10", "tls11", "tls12" or "tls13"`,
|
||||
Default: "tls12",
|
||||
},
|
||||
TraceEnable: &codersdk.BoolFlag{
|
||||
Name: "Trace Enabled",
|
||||
Flag: "trace",
|
||||
EnvVar: "CODER_TRACE",
|
||||
Description: "Whether application tracing data is collected.",
|
||||
},
|
||||
SecureAuthCookie: &codersdk.BoolFlag{
|
||||
Name: "Secure Auth Cookie",
|
||||
Flag: "secure-auth-cookie",
|
||||
EnvVar: "CODER_SECURE_AUTH_COOKIE",
|
||||
Description: "Controls if the 'Secure' property is set on browser session cookies",
|
||||
},
|
||||
SSHKeygenAlgorithm: &codersdk.StringFlag{
|
||||
Name: "SSH Keygen Algorithm",
|
||||
Flag: "ssh-keygen-algorithm",
|
||||
EnvVar: "CODER_SSH_KEYGEN_ALGORITHM",
|
||||
Description: "The algorithm to use for generating ssh keys. " +
|
||||
`Accepted values are "ed25519", "ecdsa", or "rsa4096"`,
|
||||
Default: "ed25519",
|
||||
},
|
||||
AutoImportTemplates: &codersdk.StringArrayFlag{
|
||||
Name: "Auto Import Templates",
|
||||
Flag: "auto-import-template",
|
||||
EnvVar: "CODER_TEMPLATE_AUTOIMPORT",
|
||||
Description: "Templates to auto-import. Available auto-importable templates are: kubernetes",
|
||||
Hidden: true,
|
||||
Default: []string{},
|
||||
},
|
||||
MetricsCacheRefreshInterval: &codersdk.DurationFlag{
|
||||
Name: "Metrics Cache Refresh Interval",
|
||||
Flag: "metrics-cache-refresh-interval",
|
||||
EnvVar: "CODER_METRICS_CACHE_REFRESH_INTERVAL",
|
||||
Description: "How frequently metrics are refreshed",
|
||||
Hidden: true,
|
||||
Default: time.Hour,
|
||||
},
|
||||
AgentStatRefreshInterval: &codersdk.DurationFlag{
|
||||
Name: "Agent Stats Refresh Interval",
|
||||
Flag: "agent-stats-refresh-interval",
|
||||
EnvVar: "CODER_AGENT_STATS_REFRESH_INTERVAL",
|
||||
Description: "How frequently agent stats are recorded",
|
||||
Hidden: true,
|
||||
Default: 10 * time.Minute,
|
||||
},
|
||||
Verbose: &codersdk.BoolFlag{
|
||||
Name: "Verbose Logging",
|
||||
Flag: "verbose",
|
||||
EnvVar: "CODER_VERBOSE",
|
||||
Shorthand: "v",
|
||||
Description: "Enables verbose logging.",
|
||||
},
|
||||
AuditLogging: &codersdk.BoolFlag{
|
||||
Name: "Audit Logging",
|
||||
Flag: "audit-logging",
|
||||
EnvVar: "CODER_AUDIT_LOGGING",
|
||||
Description: "Specifies whether audit logging is enabled.",
|
||||
Default: true,
|
||||
Enterprise: true,
|
||||
},
|
||||
BrowserOnly: &codersdk.BoolFlag{
|
||||
Name: "Browser Only",
|
||||
Flag: "browser-only",
|
||||
EnvVar: "CODER_BROWSER_ONLY",
|
||||
Description: "Whether Coder only allows connections to workspaces via the browser.",
|
||||
Enterprise: true,
|
||||
},
|
||||
SCIMAuthHeader: &codersdk.StringFlag{
|
||||
Name: "SCIM Authentication Header",
|
||||
Flag: "scim-auth-header",
|
||||
EnvVar: "CODER_SCIM_API_KEY",
|
||||
Description: "Enables SCIM and sets the authentication header for the built-in SCIM server. New users are automatically created with OIDC authentication.",
|
||||
Secret: true,
|
||||
Enterprise: true,
|
||||
},
|
||||
UserWorkspaceQuota: &codersdk.IntFlag{
|
||||
Name: "User Workspace Quota",
|
||||
Flag: "user-workspace-quota",
|
||||
EnvVar: "CODER_USER_WORKSPACE_QUOTA",
|
||||
Description: "Enables and sets a limit on how many workspaces each user can create.",
|
||||
Default: 0,
|
||||
Enterprise: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func RemoveSensitiveValues(df codersdk.DeploymentFlags) codersdk.DeploymentFlags {
|
||||
v := reflect.ValueOf(&df).Elem()
|
||||
t := v.Type()
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
fv := v.Field(i)
|
||||
if vp, ok := fv.Interface().(*codersdk.StringFlag); ok {
|
||||
if vp.Secret && vp.Value != "" {
|
||||
// Make a copy and remove the value.
|
||||
v := *vp
|
||||
v.Value = secretValue
|
||||
fv.Set(reflect.ValueOf(&v))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return df
|
||||
}
|
||||
|
||||
//nolint:revive
|
||||
func AttachFlags(flagset *pflag.FlagSet, df *codersdk.DeploymentFlags, enterprise bool) {
|
||||
v := reflect.ValueOf(df).Elem()
|
||||
t := v.Type()
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
fv := v.Field(i)
|
||||
fve := fv.Elem()
|
||||
e := fve.FieldByName("Enterprise").Bool()
|
||||
if e != enterprise {
|
||||
continue
|
||||
}
|
||||
if e {
|
||||
d := fve.FieldByName("Description").String()
|
||||
d += cliui.Styles.Keyword.Render(" This is an Enterprise feature. Contact sales@coder.com for licensing")
|
||||
fve.FieldByName("Description").SetString(d)
|
||||
}
|
||||
|
||||
switch v := fv.Interface().(type) {
|
||||
case *codersdk.StringFlag:
|
||||
StringFlag(flagset, v)
|
||||
case *codersdk.StringArrayFlag:
|
||||
StringArrayFlag(flagset, v)
|
||||
case *codersdk.IntFlag:
|
||||
IntFlag(flagset, v)
|
||||
case *codersdk.BoolFlag:
|
||||
BoolFlag(flagset, v)
|
||||
case *codersdk.DurationFlag:
|
||||
DurationFlag(flagset, v)
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown flag type: %T", v))
|
||||
}
|
||||
if fve.FieldByName("Hidden").Bool() {
|
||||
_ = flagset.MarkHidden(fve.FieldByName("Flag").String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func StringFlag(flagset *pflag.FlagSet, fl *codersdk.StringFlag) {
|
||||
cliflag.StringVarP(flagset,
|
||||
&fl.Value,
|
||||
fl.Flag,
|
||||
fl.Shorthand,
|
||||
fl.EnvVar,
|
||||
fl.Default,
|
||||
fl.Description,
|
||||
)
|
||||
}
|
||||
|
||||
func BoolFlag(flagset *pflag.FlagSet, fl *codersdk.BoolFlag) {
|
||||
cliflag.BoolVarP(flagset,
|
||||
&fl.Value,
|
||||
fl.Flag,
|
||||
fl.Shorthand,
|
||||
fl.EnvVar,
|
||||
fl.Default,
|
||||
fl.Description,
|
||||
)
|
||||
}
|
||||
|
||||
func IntFlag(flagset *pflag.FlagSet, fl *codersdk.IntFlag) {
|
||||
cliflag.IntVarP(flagset,
|
||||
&fl.Value,
|
||||
fl.Flag,
|
||||
fl.Shorthand,
|
||||
fl.EnvVar,
|
||||
fl.Default,
|
||||
fl.Description,
|
||||
)
|
||||
}
|
||||
|
||||
func DurationFlag(flagset *pflag.FlagSet, fl *codersdk.DurationFlag) {
|
||||
cliflag.DurationVarP(flagset,
|
||||
&fl.Value,
|
||||
fl.Flag,
|
||||
fl.Shorthand,
|
||||
fl.EnvVar,
|
||||
fl.Default,
|
||||
fl.Description,
|
||||
)
|
||||
}
|
||||
|
||||
func StringArrayFlag(flagset *pflag.FlagSet, fl *codersdk.StringArrayFlag) {
|
||||
cliflag.StringArrayVarP(flagset,
|
||||
&fl.Value,
|
||||
fl.Flag,
|
||||
fl.Shorthand,
|
||||
fl.EnvVar,
|
||||
fl.Default,
|
||||
fl.Description,
|
||||
)
|
||||
}
|
||||
|
||||
func defaultCacheDir() string {
|
||||
defaultCacheDir, err := os.UserCacheDir()
|
||||
if err != nil {
|
||||
defaultCacheDir = os.TempDir()
|
||||
}
|
||||
if dir := os.Getenv("CACHE_DIRECTORY"); dir != "" {
|
||||
// For compatibility with systemd.
|
||||
defaultCacheDir = dir
|
||||
}
|
||||
|
||||
return filepath.Join(defaultCacheDir, "coder")
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
package deployment_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/pflag"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/cli/deployment"
|
||||
)
|
||||
|
||||
func TestFlags(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
df := deployment.Flags()
|
||||
fs := pflag.NewFlagSet("test", pflag.ContinueOnError)
|
||||
deployment.AttachFlags(fs, df, false)
|
||||
|
||||
require.NotNil(t, fs.Lookup("access-url"))
|
||||
require.False(t, fs.Lookup("access-url").Hidden)
|
||||
require.True(t, fs.Lookup("telemetry-url").Hidden)
|
||||
require.NotEmpty(t, fs.Lookup("telemetry-url").DefValue)
|
||||
require.Nil(t, fs.Lookup("audit-logging"))
|
||||
|
||||
df = deployment.Flags()
|
||||
fs = pflag.NewFlagSet("test-enterprise", pflag.ContinueOnError)
|
||||
deployment.AttachFlags(fs, df, true)
|
||||
|
||||
require.Nil(t, fs.Lookup("access-url"))
|
||||
require.NotNil(t, fs.Lookup("audit-logging"))
|
||||
require.Contains(t, fs.Lookup("audit-logging").Usage, "This is an Enterprise feature")
|
||||
}
|
||||
+1
-1
@@ -22,7 +22,7 @@ func dotfiles() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "dotfiles [git_repo_url]",
|
||||
Args: cobra.ExactArgs(1),
|
||||
Short: "Check out and install a dotfiles repository.",
|
||||
Short: "Checkout and install a dotfiles repository from a Git URL",
|
||||
Example: formatExamples(
|
||||
example{
|
||||
Description: "Check out and install a dotfiles repository without prompts",
|
||||
|
||||
+121
-4
@@ -1,9 +1,15 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
@@ -13,16 +19,30 @@ import (
|
||||
)
|
||||
|
||||
func gitssh() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
cmd := &cobra.Command{
|
||||
Use: "gitssh",
|
||||
Hidden: true,
|
||||
Short: `Wraps the "ssh" command and uses the coder gitssh key for authentication`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
ctx := cmd.Context()
|
||||
env := os.Environ()
|
||||
|
||||
// Catch interrupt signals to ensure the temporary private
|
||||
// key file is cleaned up on most cases.
|
||||
ctx, stop := signal.NotifyContext(ctx, interruptSignals...)
|
||||
defer stop()
|
||||
|
||||
// Early check so errors are reported immediately.
|
||||
identityFiles, err := parseIdentityFilesForHost(ctx, args, env)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client, err := createAgentClient(cmd)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create agent client: %w", err)
|
||||
}
|
||||
key, err := client.AgentGitSSHKey(cmd.Context())
|
||||
key, err := client.AgentGitSSHKey(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get agent git ssh token: %w", err)
|
||||
}
|
||||
@@ -44,8 +64,23 @@ func gitssh() *cobra.Command {
|
||||
return xerrors.Errorf("close temp gitsshkey file: %w", err)
|
||||
}
|
||||
|
||||
args = append([]string{"-i", privateKeyFile.Name()}, args...)
|
||||
c := exec.CommandContext(cmd.Context(), "ssh", args...)
|
||||
// Append our key, giving precedence to user keys. Note that
|
||||
// OpenSSH server are typically configured with MaxAuthTries
|
||||
// set to the default value of 6. This means that only the 6
|
||||
// first keys can be tried. However, we will assume that if
|
||||
// a user has configured 6+ keys for a host, they know what
|
||||
// they're doing. This behavior is critical if a server has
|
||||
// been configured with MaxAuthTries set to 1.
|
||||
identityFiles = append(identityFiles, privateKeyFile.Name())
|
||||
|
||||
var identityArgs []string
|
||||
for _, id := range identityFiles {
|
||||
identityArgs = append(identityArgs, "-i", id)
|
||||
}
|
||||
|
||||
args = append(identityArgs, args...)
|
||||
c := exec.CommandContext(ctx, "ssh", args...)
|
||||
c.Env = append(c.Env, env...)
|
||||
c.Stderr = cmd.ErrOrStderr()
|
||||
c.Stdout = cmd.OutOrStdout()
|
||||
c.Stdin = cmd.InOrStdin()
|
||||
@@ -69,4 +104,86 @@ func gitssh() *cobra.Command {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
// fallbackIdentityFiles is the list of identity files SSH tries when
|
||||
// none have been defined for a host.
|
||||
var fallbackIdentityFiles = strings.Join([]string{
|
||||
"identityfile ~/.ssh/id_rsa",
|
||||
"identityfile ~/.ssh/id_dsa",
|
||||
"identityfile ~/.ssh/id_ecdsa",
|
||||
"identityfile ~/.ssh/id_ecdsa_sk",
|
||||
"identityfile ~/.ssh/id_ed25519",
|
||||
"identityfile ~/.ssh/id_ed25519_sk",
|
||||
"identityfile ~/.ssh/id_xmss",
|
||||
}, "\n")
|
||||
|
||||
// parseIdentityFilesForHost uses ssh -G to discern what SSH keys have
|
||||
// been enabled for the host (via the users SSH config) and returns a
|
||||
// list of existing identity files.
|
||||
//
|
||||
// We do this because when no keys are defined for a host, SSH uses
|
||||
// fallback keys (see above). However, by passing `-i` to attach our
|
||||
// private key, we're effectively disabling the fallback keys.
|
||||
//
|
||||
// Example invocation:
|
||||
//
|
||||
// ssh -G -o SendEnv=GIT_PROTOCOL git@github.com git-upload-pack 'coder/coder'
|
||||
//
|
||||
// The extra arguments work without issue and lets us run the command
|
||||
// as-is without stripping out the excess (git-upload-pack 'coder/coder').
|
||||
func parseIdentityFilesForHost(ctx context.Context, args, env []string) (identityFiles []string, error error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get user home dir failed: %w", err)
|
||||
}
|
||||
|
||||
var outBuf bytes.Buffer
|
||||
var r io.Reader = &outBuf
|
||||
|
||||
args = append([]string{"-G"}, args...)
|
||||
cmd := exec.CommandContext(ctx, "ssh", args...)
|
||||
cmd.Env = append(cmd.Env, env...)
|
||||
cmd.Stdout = &outBuf
|
||||
cmd.Stderr = io.Discard
|
||||
err = cmd.Run()
|
||||
if err != nil {
|
||||
// If ssh -G failed, the SSH version is likely too old, fallback
|
||||
// to using the default identity files.
|
||||
r = strings.NewReader(fallbackIdentityFiles)
|
||||
}
|
||||
|
||||
s := bufio.NewScanner(r)
|
||||
for s.Scan() {
|
||||
line := s.Text()
|
||||
if strings.HasPrefix(line, "identityfile ") {
|
||||
id := strings.TrimPrefix(line, "identityfile ")
|
||||
if strings.HasPrefix(id, "~/") {
|
||||
id = home + id[1:]
|
||||
}
|
||||
// OpenSSH on Windows is weird, it supports using (and does
|
||||
// use) mixed \ and / in paths.
|
||||
//
|
||||
// Example: C:\Users\ZeroCool/.ssh/known_hosts
|
||||
//
|
||||
// To check the file existence in Go, though, we want to use
|
||||
// proper Windows paths.
|
||||
// OpenSSH is amazing, this will work on Windows too:
|
||||
// C:\Users\ZeroCool/.ssh/id_rsa
|
||||
id = filepath.FromSlash(id)
|
||||
|
||||
// Only include the identity file if it exists.
|
||||
if _, err := os.Stat(id); err == nil {
|
||||
identityFiles = append(identityFiles, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := s.Err(); err != nil {
|
||||
// This should never happen, the check is for completeness.
|
||||
return nil, xerrors.Errorf("scan ssh output: %w", err)
|
||||
}
|
||||
|
||||
return identityFiles, nil
|
||||
}
|
||||
|
||||
+224
-79
@@ -2,8 +2,16 @@ package cli_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
@@ -17,99 +25,236 @@ import (
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/provisioner/echo"
|
||||
"github.com/coder/coder/provisionersdk/proto"
|
||||
"github.com/coder/coder/pty/ptytest"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
func prepareTestGitSSH(ctx context.Context, t *testing.T) (*codersdk.Client, string, gossh.PublicKey) {
|
||||
t.Helper()
|
||||
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer t.Cleanup(cancel) // Defer so that cancel is the first cleanup.
|
||||
|
||||
// get user public key
|
||||
keypair, err := client.GitSSHKey(ctx, codersdk.Me)
|
||||
require.NoError(t, err)
|
||||
//nolint:dogsled
|
||||
pubkey, _, _, _, err := gossh.ParseAuthorizedKey([]byte(keypair.PublicKey))
|
||||
require.NoError(t, err)
|
||||
|
||||
// setup template
|
||||
agentToken := uuid.NewString()
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionDryRun: echo.ProvisionComplete,
|
||||
Provision: []*proto.Provision_Response{{
|
||||
Type: &proto.Provision_Response_Complete{
|
||||
Complete: &proto.Provision_Complete{
|
||||
Resources: []*proto.Resource{{
|
||||
Name: "somename",
|
||||
Type: "someinstance",
|
||||
Agents: []*proto.Agent{{
|
||||
Auth: &proto.Agent_Token{
|
||||
Token: agentToken,
|
||||
},
|
||||
}},
|
||||
}},
|
||||
},
|
||||
},
|
||||
}},
|
||||
})
|
||||
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
|
||||
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
// start workspace agent
|
||||
cmd, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String())
|
||||
agentClient := client
|
||||
clitest.SetupConfig(t, agentClient, root)
|
||||
|
||||
errC := make(chan error, 1)
|
||||
go func() {
|
||||
errC <- cmd.ExecuteContext(ctx)
|
||||
}()
|
||||
t.Cleanup(func() { require.NoError(t, <-errC) })
|
||||
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
|
||||
return agentClient, agentToken, pubkey
|
||||
}
|
||||
|
||||
func serveSSHForGitSSH(t *testing.T, handler func(ssh.Session), pubkeys ...gossh.PublicKey) *net.TCPAddr {
|
||||
t.Helper()
|
||||
|
||||
// start ssh server
|
||||
l, err := net.Listen("tcp", "localhost:0")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = l.Close() })
|
||||
|
||||
serveOpts := []ssh.Option{
|
||||
ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
|
||||
for _, pubkey := range pubkeys {
|
||||
if ssh.KeysEqual(pubkey, key) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}),
|
||||
}
|
||||
errC := make(chan error, 1)
|
||||
go func() {
|
||||
// as long as we get a successful session we don't care if the server errors
|
||||
errC <- ssh.Serve(l, handler, serveOpts...)
|
||||
}()
|
||||
t.Cleanup(func() {
|
||||
_ = l.Close() // Ensure server shutdown.
|
||||
<-errC
|
||||
})
|
||||
|
||||
// start ssh session
|
||||
addr, ok := l.Addr().(*net.TCPAddr)
|
||||
require.True(t, ok)
|
||||
|
||||
return addr
|
||||
}
|
||||
|
||||
func writePrivateKeyToFile(t *testing.T, name string, key *ecdsa.PrivateKey) {
|
||||
t.Helper()
|
||||
|
||||
b, err := x509.MarshalPKCS8PrivateKey(key)
|
||||
require.NoError(t, err)
|
||||
b = pem.EncodeToMemory(&pem.Block{
|
||||
Type: "PRIVATE KEY",
|
||||
Bytes: b,
|
||||
})
|
||||
|
||||
err = os.WriteFile(name, b, 0o600)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestGitSSH(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("Dial", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// get user public key
|
||||
keypair, err := client.GitSSHKey(context.Background(), codersdk.Me)
|
||||
require.NoError(t, err)
|
||||
publicKey, _, _, _, err := gossh.ParseAuthorizedKey([]byte(keypair.PublicKey))
|
||||
require.NoError(t, err)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
// setup template
|
||||
agentToken := uuid.NewString()
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
ProvisionDryRun: echo.ProvisionComplete,
|
||||
Provision: []*proto.Provision_Response{{
|
||||
Type: &proto.Provision_Response_Complete{
|
||||
Complete: &proto.Provision_Complete{
|
||||
Resources: []*proto.Resource{{
|
||||
Name: "somename",
|
||||
Type: "someinstance",
|
||||
Agents: []*proto.Agent{{
|
||||
Auth: &proto.Agent_Token{
|
||||
Token: agentToken,
|
||||
},
|
||||
}},
|
||||
}},
|
||||
},
|
||||
},
|
||||
}},
|
||||
})
|
||||
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
|
||||
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
// start workspace agent
|
||||
cmd, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String(), "--wireguard=false")
|
||||
agentClient := client
|
||||
clitest.SetupConfig(t, agentClient, root)
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
agentErrC := make(chan error)
|
||||
go func() {
|
||||
agentErrC <- cmd.ExecuteContext(ctx)
|
||||
}()
|
||||
|
||||
coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
|
||||
resources, err := client.WorkspaceResourcesByBuild(context.Background(), workspace.LatestBuild.ID)
|
||||
require.NoError(t, err)
|
||||
dialer, err := client.DialWorkspaceAgent(context.Background(), resources[0].Agents[0].ID, nil)
|
||||
require.NoError(t, err)
|
||||
defer dialer.Close()
|
||||
_, err = dialer.Ping()
|
||||
require.NoError(t, err)
|
||||
|
||||
// start ssh server
|
||||
l, err := net.Listen("tcp", "localhost:0")
|
||||
require.NoError(t, err)
|
||||
defer l.Close()
|
||||
publicKeyOption := ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
|
||||
return ssh.KeysEqual(publicKey, key)
|
||||
})
|
||||
client, token, pubkey := prepareTestGitSSH(ctx, t)
|
||||
var inc int64
|
||||
sshErrC := make(chan error)
|
||||
go func() {
|
||||
// as long as we get a successful session we don't care if the server errors
|
||||
_ = ssh.Serve(l, func(s ssh.Session) {
|
||||
atomic.AddInt64(&inc, 1)
|
||||
t.Log("got authenticated session")
|
||||
sshErrC <- s.Exit(0)
|
||||
}, publicKeyOption)
|
||||
}()
|
||||
errC := make(chan error, 1)
|
||||
addr := serveSSHForGitSSH(t, func(s ssh.Session) {
|
||||
atomic.AddInt64(&inc, 1)
|
||||
t.Log("got authenticated session")
|
||||
select {
|
||||
case errC <- s.Exit(0):
|
||||
default:
|
||||
t.Error("error channel is full")
|
||||
}
|
||||
}, pubkey)
|
||||
|
||||
// start ssh session
|
||||
addr, ok := l.Addr().(*net.TCPAddr)
|
||||
require.True(t, ok)
|
||||
// set to agent config dir
|
||||
gitsshCmd, _ := clitest.New(t, "gitssh", "--agent-url", agentClient.URL.String(), "--agent-token", agentToken, "--", fmt.Sprintf("-p%d", addr.Port), "-o", "StrictHostKeyChecking=no", "-o", "IdentitiesOnly=yes", "127.0.0.1")
|
||||
err = gitsshCmd.ExecuteContext(context.Background())
|
||||
cmd, _ := clitest.New(t,
|
||||
"gitssh",
|
||||
"--agent-url", client.URL.String(),
|
||||
"--agent-token", token,
|
||||
"--",
|
||||
fmt.Sprintf("-p%d", addr.Port),
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "IdentitiesOnly=yes",
|
||||
"127.0.0.1",
|
||||
)
|
||||
err := cmd.ExecuteContext(ctx)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, inc)
|
||||
|
||||
err = <-sshErrC
|
||||
require.NoError(t, err, "error in ssh session exit")
|
||||
|
||||
cancelFunc()
|
||||
err = <-agentErrC
|
||||
err = <-errC
|
||||
require.NoError(t, err, "error in agent execute")
|
||||
})
|
||||
|
||||
t.Run("Local SSH Keys", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
home := t.TempDir()
|
||||
sshdir := filepath.Join(home, ".ssh")
|
||||
err := os.MkdirAll(sshdir, 0o700)
|
||||
require.NoError(t, err)
|
||||
|
||||
idFile := filepath.Join(sshdir, "id_ed25519")
|
||||
privkey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
localPubkey, err := gossh.NewPublicKey(&privkey.PublicKey)
|
||||
require.NoError(t, err)
|
||||
writePrivateKeyToFile(t, idFile, privkey)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
client, token, coderPubkey := prepareTestGitSSH(ctx, t)
|
||||
|
||||
authkey := make(chan gossh.PublicKey, 1)
|
||||
addr := serveSSHForGitSSH(t, func(s ssh.Session) {
|
||||
t.Logf("authenticated with: %s", gossh.MarshalAuthorizedKey(s.PublicKey()))
|
||||
select {
|
||||
case authkey <- s.PublicKey():
|
||||
default:
|
||||
t.Error("authkey channel is full")
|
||||
}
|
||||
}, localPubkey, coderPubkey)
|
||||
|
||||
// Create a new config which sets an identity file.
|
||||
config := filepath.Join(sshdir, "config")
|
||||
knownHosts := filepath.Join(sshdir, "known_hosts")
|
||||
err = os.WriteFile(config, []byte(strings.Join([]string{
|
||||
"Host mytest",
|
||||
" HostName 127.0.0.1",
|
||||
fmt.Sprintf(" Port %d", addr.Port),
|
||||
" StrictHostKeyChecking no",
|
||||
" UserKnownHostsFile=" + knownHosts,
|
||||
" IdentitiesOnly yes",
|
||||
" IdentityFile=" + idFile,
|
||||
}, "\n")), 0o600)
|
||||
require.NoError(t, err)
|
||||
|
||||
pty := ptytest.New(t)
|
||||
cmdArgs := []string{
|
||||
"gitssh",
|
||||
"--agent-url", client.URL.String(),
|
||||
"--agent-token", token,
|
||||
"--",
|
||||
"-F", config,
|
||||
"mytest",
|
||||
}
|
||||
// Test authentication via local private key.
|
||||
cmd, _ := clitest.New(t, cmdArgs...)
|
||||
cmd.SetOut(pty.Output())
|
||||
cmd.SetErr(pty.Output())
|
||||
err = cmd.ExecuteContext(ctx)
|
||||
require.NoError(t, err)
|
||||
select {
|
||||
case key := <-authkey:
|
||||
require.Equal(t, localPubkey, key)
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout waiting for auth")
|
||||
}
|
||||
|
||||
// Delete the local private key.
|
||||
err = os.Remove(idFile)
|
||||
require.NoError(t, err)
|
||||
|
||||
// With the local file deleted, the coder key should be used.
|
||||
cmd, _ = clitest.New(t, cmdArgs...)
|
||||
cmd.SetOut(pty.Output())
|
||||
cmd.SetErr(pty.Output())
|
||||
err = cmd.ExecuteContext(ctx)
|
||||
require.NoError(t, err)
|
||||
select {
|
||||
case key := <-authkey:
|
||||
require.Equal(t, coderPubkey, key)
|
||||
case <-ctx.Done():
|
||||
t.Fatal("timeout waiting for auth")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
+26
-8
@@ -2,6 +2,7 @@ package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -58,15 +59,17 @@ func workspaceListRowFromWorkspace(now time.Time, usersByID map[uuid.UUID]coders
|
||||
|
||||
func list() *cobra.Command {
|
||||
var (
|
||||
all bool
|
||||
columns []string
|
||||
defaultQuery = "owner:me"
|
||||
searchQuery string
|
||||
all bool
|
||||
columns []string
|
||||
defaultQuery = "owner:me"
|
||||
searchQuery string
|
||||
me bool
|
||||
displayWorkspaces []workspaceListRow
|
||||
)
|
||||
cmd := &cobra.Command{
|
||||
Annotations: workspaceCommand,
|
||||
Use: "list",
|
||||
Short: "List all workspaces",
|
||||
Short: "List workspaces",
|
||||
Aliases: []string{"ls"},
|
||||
Args: cobra.ExactArgs(0),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
@@ -80,6 +83,14 @@ func list() *cobra.Command {
|
||||
if all && searchQuery == defaultQuery {
|
||||
filter.FilterQuery = ""
|
||||
}
|
||||
|
||||
if me {
|
||||
myUser, err := client.User(cmd.Context(), codersdk.Me)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
filter.Owner = myUser.Username
|
||||
}
|
||||
workspaces, err := client.Workspaces(cmd.Context(), filter)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -101,7 +112,7 @@ func list() *cobra.Command {
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
displayWorkspaces := make([]workspaceListRow, len(workspaces))
|
||||
displayWorkspaces = make([]workspaceListRow, len(workspaces))
|
||||
for i, workspace := range workspaces {
|
||||
displayWorkspaces[i] = workspaceListRowFromWorkspace(now, usersByID, workspace)
|
||||
}
|
||||
@@ -115,10 +126,17 @@ func list() *cobra.Command {
|
||||
return err
|
||||
},
|
||||
}
|
||||
|
||||
availColumns, err := cliui.TableHeaders(displayWorkspaces)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
columnString := strings.Join(availColumns[:], ", ")
|
||||
|
||||
cmd.Flags().BoolVarP(&all, "all", "a", false,
|
||||
"Specifies whether all workspaces will be listed or not.")
|
||||
cmd.Flags().StringArrayVarP(&columns, "column", "c", nil,
|
||||
"Specify a column to filter in the table.")
|
||||
cmd.Flags().StringVar(&searchQuery, "search", defaultQuery, "Search for a workspace with a query.")
|
||||
fmt.Sprintf("Specify a column to filter in the table. Available columns are: %v", columnString))
|
||||
cmd.Flags().StringVar(&searchQuery, "search", "", "Search for a workspace with a query.")
|
||||
return cmd
|
||||
}
|
||||
|
||||
+9
-6
@@ -45,7 +45,7 @@ func login() *cobra.Command {
|
||||
)
|
||||
cmd := &cobra.Command{
|
||||
Use: "login <url>",
|
||||
Short: "Authenticate with a Coder deployment",
|
||||
Short: "Authenticate with Coder deployment",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
rawURL := args[0]
|
||||
@@ -66,7 +66,10 @@ func login() *cobra.Command {
|
||||
serverURL.Scheme = "https"
|
||||
}
|
||||
|
||||
client := codersdk.New(serverURL)
|
||||
client, err := createUnauthenticatedClient(cmd, serverURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Try to check the version of the server prior to logging in.
|
||||
// It may be useful to warn the user if they are trying to login
|
||||
@@ -80,7 +83,7 @@ func login() *cobra.Command {
|
||||
|
||||
hasInitialUser, err := client.HasFirstUser(cmd.Context())
|
||||
if err != nil {
|
||||
return xerrors.Errorf("has initial user: %w", err)
|
||||
return xerrors.Errorf("Failed to check server %q for first user, is the URL correct and is coder accessible from your browser? Error - has initial user: %w", serverURL.String(), err)
|
||||
}
|
||||
if !hasInitialUser {
|
||||
_, _ = fmt.Fprintf(cmd.OutOrStdout(), caret+"Your Coder deployment hasn't been set up!\n")
|
||||
@@ -245,9 +248,9 @@ func login() *cobra.Command {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
cliflag.StringVarP(cmd.Flags(), &email, "email", "e", "CODER_EMAIL", "", "Specifies an email address to authenticate with.")
|
||||
cliflag.StringVarP(cmd.Flags(), &username, "username", "u", "CODER_USERNAME", "", "Specifies a username to authenticate with.")
|
||||
cliflag.StringVarP(cmd.Flags(), &password, "password", "p", "CODER_PASSWORD", "", "Specifies a password to authenticate with.")
|
||||
cliflag.StringVarP(cmd.Flags(), &email, "first-user-email", "", "CODER_FIRST_USER_EMAIL", "", "Specifies an email address to use if creating the first user for the deployment.")
|
||||
cliflag.StringVarP(cmd.Flags(), &username, "first-user-username", "", "CODER_FIRST_USER_USERNAME", "", "Specifies a username to use if creating the first user for the deployment.")
|
||||
cliflag.StringVarP(cmd.Flags(), &password, "first-user-password", "", "CODER_FIRST_USER_PASSWORD", "", "Specifies a password to use if creating the first user for the deployment.")
|
||||
return cmd
|
||||
}
|
||||
|
||||
|
||||
+11
-1
@@ -2,6 +2,7 @@ package cli_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -23,6 +24,15 @@ func TestLogin(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("InitialUserBadLoginURL", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
badLoginURL := "https://fcca2077f06e68aaf9"
|
||||
root, _ := clitest.New(t, "login", badLoginURL)
|
||||
err := root.Execute()
|
||||
errMsg := fmt.Sprintf("Failed to check server %q for first user, is the URL correct and is coder accessible from your browser?", badLoginURL)
|
||||
require.ErrorContains(t, err, errMsg)
|
||||
})
|
||||
|
||||
t.Run("InitialUserTTY", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
@@ -64,7 +74,7 @@ func TestLogin(t *testing.T) {
|
||||
// accurately detect Windows ptys when they are not attached to a process:
|
||||
// https://github.com/mattn/go-isatty/issues/59
|
||||
doneChan := make(chan struct{})
|
||||
root, _ := clitest.New(t, "login", client.URL.String(), "--username", "testuser", "--email", "user@coder.com", "--password", "password")
|
||||
root, _ := clitest.New(t, "login", client.URL.String(), "--first-user-username", "testuser", "--first-user-email", "user@coder.com", "--first-user-password", "password")
|
||||
pty := ptytest.New(t)
|
||||
root.SetIn(pty.Input())
|
||||
root.SetOut(pty.Output())
|
||||
|
||||
+1
-1
@@ -14,7 +14,7 @@ import (
|
||||
func logout() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "logout",
|
||||
Short: "Remove the local authenticated session",
|
||||
Short: "Unauthenticate your local session",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
client, err := CreateClient(cmd)
|
||||
if err != nil {
|
||||
|
||||
@@ -20,6 +20,9 @@ func parameters() *cobra.Command {
|
||||
// constructing curl requests.
|
||||
Hidden: true,
|
||||
Aliases: []string{"params"},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return cmd.Help()
|
||||
},
|
||||
}
|
||||
cmd.AddCommand(
|
||||
parameterList(),
|
||||
|
||||
+120
-127
@@ -6,32 +6,30 @@ import (
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/pion/udp"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/agent"
|
||||
"github.com/coder/coder/cli/cliflag"
|
||||
"github.com/coder/coder/cli/cliui"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
||||
func portForward() *cobra.Command {
|
||||
var (
|
||||
tcpForwards []string // <port>:<port>
|
||||
udpForwards []string // <port>:<port>
|
||||
unixForwards []string // <path>:<path> OR <port>:<path>
|
||||
wireguard bool
|
||||
tcpForwards []string // <port>:<port>
|
||||
udpForwards []string // <port>:<port>
|
||||
)
|
||||
cmd := &cobra.Command{
|
||||
Use: "port-forward <workspace>",
|
||||
Short: "Forward one or more ports from the local machine to the remote workspace",
|
||||
Short: "Forward ports from machine to a workspace",
|
||||
Aliases: []string{"tunnel"},
|
||||
Args: cobra.ExactArgs(1),
|
||||
Example: formatExamples(
|
||||
@@ -43,24 +41,20 @@ func portForward() *cobra.Command {
|
||||
Description: "Port forward a single UDP port from port 9000 to port 9000 on your local machine",
|
||||
Command: "coder port-forward <workspace> --udp 9000",
|
||||
},
|
||||
example{
|
||||
Description: "Forward a Unix socket in the workspace to a local Unix socket",
|
||||
Command: "coder port-forward <workspace> --unix ./local.sock:~/remote.sock",
|
||||
},
|
||||
example{
|
||||
Description: "Forward a Unix socket in the workspace to a local TCP port",
|
||||
Command: "coder port-forward <workspace> --unix 8080:~/remote.sock",
|
||||
},
|
||||
example{
|
||||
Description: "Port forward multiple TCP ports and a UDP port",
|
||||
Command: "coder port-forward <workspace> --tcp 8080:8080 --tcp 9000:3000 --udp 5353:53",
|
||||
},
|
||||
example{
|
||||
Description: "Port forward multiple ports (TCP or UDP) in condensed syntax",
|
||||
Command: "coder port-forward <workspace> --tcp 8080,9000:3000,9090-9092,10000-10002:10010-10012",
|
||||
},
|
||||
),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
defer cancel()
|
||||
|
||||
specs, err := parsePortForwards(tcpForwards, udpForwards, unixForwards)
|
||||
specs, err := parsePortForwards(tcpForwards, udpForwards)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("parse port-forward specs: %w", err)
|
||||
}
|
||||
@@ -101,12 +95,7 @@ func portForward() *cobra.Command {
|
||||
return xerrors.Errorf("await agent: %w", err)
|
||||
}
|
||||
|
||||
var conn agent.Conn
|
||||
if !wireguard {
|
||||
conn, err = client.DialWorkspaceAgent(ctx, workspaceAgent.ID, nil)
|
||||
} else {
|
||||
conn, err = client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, workspaceAgent.ID)
|
||||
}
|
||||
conn, err := client.DialWorkspaceAgent(ctx, workspaceAgent.ID, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -149,29 +138,41 @@ func portForward() *cobra.Command {
|
||||
case <-ctx.Done():
|
||||
closeErr = ctx.Err()
|
||||
case <-sigs:
|
||||
_, _ = fmt.Fprintln(cmd.OutOrStderr(), "Received signal, closing all listeners and active connections")
|
||||
closeErr = xerrors.New("signal received")
|
||||
_, _ = fmt.Fprintln(cmd.OutOrStderr(), "\nReceived signal, closing all listeners and active connections")
|
||||
}
|
||||
|
||||
cancel()
|
||||
closeAllListeners()
|
||||
}()
|
||||
|
||||
ticker := time.NewTicker(250 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
}
|
||||
|
||||
_, err = conn.Ping(ctx)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
ticker.Stop()
|
||||
_, _ = fmt.Fprintln(cmd.OutOrStderr(), "Ready!")
|
||||
wg.Wait()
|
||||
return closeErr
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringArrayVarP(&tcpForwards, "tcp", "p", []string{}, "Forward a TCP port from the workspace to the local machine")
|
||||
cmd.Flags().StringArrayVar(&udpForwards, "udp", []string{}, "Forward a UDP port from the workspace to the local machine. The UDP connection has TCP-like semantics to support stateful UDP protocols")
|
||||
cmd.Flags().StringArrayVar(&unixForwards, "unix", []string{}, "Forward a Unix socket in the workspace to a local Unix socket or TCP port")
|
||||
cmd.Flags().BoolVarP(&wireguard, "wireguard", "", false, "Specifies whether to use wireguard networking or not.")
|
||||
_ = cmd.Flags().MarkHidden("wireguard")
|
||||
cliflag.StringArrayVarP(cmd.Flags(), &tcpForwards, "tcp", "p", "CODER_PORT_FORWARD_TCP", nil, "Forward TCP port(s) from the workspace to the local machine")
|
||||
cliflag.StringArrayVarP(cmd.Flags(), &udpForwards, "udp", "", "CODER_PORT_FORWARD_UDP", nil, "Forward UDP port(s) from the workspace to the local machine. The UDP connection has TCP-like semantics to support stateful UDP protocols")
|
||||
return cmd
|
||||
}
|
||||
|
||||
func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn agent.Conn, wg *sync.WaitGroup, spec portForwardSpec) (net.Listener, error) {
|
||||
func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn *codersdk.AgentConn, wg *sync.WaitGroup, spec portForwardSpec) (net.Listener, error) {
|
||||
_, _ = fmt.Fprintf(cmd.OutOrStderr(), "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress)
|
||||
|
||||
var (
|
||||
@@ -198,8 +199,6 @@ func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn agent.Co
|
||||
IP: net.ParseIP(host),
|
||||
Port: portInt,
|
||||
})
|
||||
case "unix":
|
||||
l, err = net.Listen(spec.listenNetwork, spec.listenAddress)
|
||||
default:
|
||||
return nil, xerrors.Errorf("unknown listen network %q", spec.listenNetwork)
|
||||
}
|
||||
@@ -213,7 +212,11 @@ func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn agent.Co
|
||||
for {
|
||||
netConn, err := l.Accept()
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintf(cmd.OutOrStderr(), "Error accepting connection from '%v://%v': %+v\n", spec.listenNetwork, spec.listenAddress, err)
|
||||
// Silently ignore net.ErrClosed errors.
|
||||
if xerrors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
}
|
||||
_, _ = fmt.Fprintf(cmd.OutOrStderr(), "Error accepting connection from '%v://%v': %v\n", spec.listenNetwork, spec.listenAddress, err)
|
||||
_, _ = fmt.Fprintln(cmd.OutOrStderr(), "Killing listener")
|
||||
return
|
||||
}
|
||||
@@ -236,65 +239,50 @@ func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn agent.Co
|
||||
}
|
||||
|
||||
type portForwardSpec struct {
|
||||
listenNetwork string // tcp, udp, unix
|
||||
listenNetwork string // tcp, udp
|
||||
listenAddress string // <ip>:<port> or path
|
||||
|
||||
dialNetwork string // tcp, udp, unix
|
||||
dialNetwork string // tcp, udp
|
||||
dialAddress string // <ip>:<port> or path
|
||||
}
|
||||
|
||||
func parsePortForwards(tcpSpecs, udpSpecs, unixSpecs []string) ([]portForwardSpec, error) {
|
||||
func parsePortForwards(tcpSpecs, udpSpecs []string) ([]portForwardSpec, error) {
|
||||
specs := []portForwardSpec{}
|
||||
|
||||
for _, spec := range tcpSpecs {
|
||||
local, remote, err := parsePortPort(spec)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to parse TCP port-forward specification %q: %w", spec, err)
|
||||
}
|
||||
|
||||
specs = append(specs, portForwardSpec{
|
||||
listenNetwork: "tcp",
|
||||
listenAddress: fmt.Sprintf("127.0.0.1:%v", local),
|
||||
dialNetwork: "tcp",
|
||||
dialAddress: fmt.Sprintf("127.0.0.1:%v", remote),
|
||||
})
|
||||
}
|
||||
|
||||
for _, spec := range udpSpecs {
|
||||
local, remote, err := parsePortPort(spec)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to parse UDP port-forward specification %q: %w", spec, err)
|
||||
}
|
||||
|
||||
specs = append(specs, portForwardSpec{
|
||||
listenNetwork: "udp",
|
||||
listenAddress: fmt.Sprintf("127.0.0.1:%v", local),
|
||||
dialNetwork: "udp",
|
||||
dialAddress: fmt.Sprintf("127.0.0.1:%v", remote),
|
||||
})
|
||||
}
|
||||
|
||||
for _, specStr := range unixSpecs {
|
||||
localPath, localTCP, remotePath, err := parseUnixUnix(specStr)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to parse Unix port-forward specification %q: %w", specStr, err)
|
||||
}
|
||||
|
||||
spec := portForwardSpec{
|
||||
dialNetwork: "unix",
|
||||
dialAddress: remotePath,
|
||||
}
|
||||
if localPath == "" {
|
||||
spec.listenNetwork = "tcp"
|
||||
spec.listenAddress = fmt.Sprintf("127.0.0.1:%v", localTCP)
|
||||
} else {
|
||||
if runtime.GOOS == "windows" {
|
||||
return nil, xerrors.Errorf("Unix port-forwarding is not supported on Windows")
|
||||
for _, specEntry := range tcpSpecs {
|
||||
for _, spec := range strings.Split(specEntry, ",") {
|
||||
ports, err := parseSrcDestPorts(spec)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to parse TCP port-forward specification %q: %w", spec, err)
|
||||
}
|
||||
|
||||
for _, port := range ports {
|
||||
specs = append(specs, portForwardSpec{
|
||||
listenNetwork: "tcp",
|
||||
listenAddress: fmt.Sprintf("127.0.0.1:%v", port.local),
|
||||
dialNetwork: "tcp",
|
||||
dialAddress: fmt.Sprintf("127.0.0.1:%v", port.remote),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, specEntry := range udpSpecs {
|
||||
for _, spec := range strings.Split(specEntry, ",") {
|
||||
ports, err := parseSrcDestPorts(spec)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to parse UDP port-forward specification %q: %w", spec, err)
|
||||
}
|
||||
|
||||
for _, port := range ports {
|
||||
specs = append(specs, portForwardSpec{
|
||||
listenNetwork: "udp",
|
||||
listenAddress: fmt.Sprintf("127.0.0.1:%v", port.local),
|
||||
dialNetwork: "udp",
|
||||
dialAddress: fmt.Sprintf("127.0.0.1:%v", port.remote),
|
||||
})
|
||||
}
|
||||
spec.listenNetwork = "unix"
|
||||
spec.listenAddress = localPath
|
||||
}
|
||||
specs = append(specs, spec)
|
||||
}
|
||||
|
||||
// Check for duplicate entries.
|
||||
@@ -322,67 +310,72 @@ func parsePort(in string) (uint16, error) {
|
||||
return uint16(port), nil
|
||||
}
|
||||
|
||||
func parseUnixPath(in string) (string, error) {
|
||||
path, err := agent.ExpandRelativeHomePath(strings.TrimSpace(in))
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("tidy path %q: %w", in, err)
|
||||
}
|
||||
|
||||
return path, nil
|
||||
type parsedSrcDestPort struct {
|
||||
local, remote uint16
|
||||
}
|
||||
|
||||
func parsePortPort(in string) (local uint16, remote uint16, err error) {
|
||||
func parseSrcDestPorts(in string) ([]parsedSrcDestPort, error) {
|
||||
parts := strings.Split(in, ":")
|
||||
if len(parts) > 2 {
|
||||
return 0, 0, xerrors.Errorf("invalid port specification %q", in)
|
||||
return nil, xerrors.Errorf("invalid port specification %q", in)
|
||||
}
|
||||
if len(parts) == 1 {
|
||||
// Duplicate the single part
|
||||
parts = append(parts, parts[0])
|
||||
}
|
||||
if !strings.Contains(parts[0], "-") {
|
||||
local, err := parsePort(parts[0])
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse local port from %q: %w", in, err)
|
||||
}
|
||||
remote, err := parsePort(parts[1])
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse remote port from %q: %w", in, err)
|
||||
}
|
||||
|
||||
local, err = parsePort(parts[0])
|
||||
if err != nil {
|
||||
return 0, 0, xerrors.Errorf("parse local port from %q: %w", in, err)
|
||||
}
|
||||
remote, err = parsePort(parts[1])
|
||||
if err != nil {
|
||||
return 0, 0, xerrors.Errorf("parse remote port from %q: %w", in, err)
|
||||
return []parsedSrcDestPort{{local: local, remote: remote}}, nil
|
||||
}
|
||||
|
||||
return local, remote, nil
|
||||
local, err := parsePortRange(parts[0])
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse local port range from %q: %w", in, err)
|
||||
}
|
||||
remote, err := parsePortRange(parts[1])
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse remote port range from %q: %w", in, err)
|
||||
}
|
||||
if len(local) != len(remote) {
|
||||
return nil, xerrors.Errorf("port ranges must be the same length, got %d ports forwarded to %d ports", len(local), len(remote))
|
||||
}
|
||||
var out []parsedSrcDestPort
|
||||
for i := range local {
|
||||
out = append(out, parsedSrcDestPort{
|
||||
local: local[i],
|
||||
remote: remote[i],
|
||||
})
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func parsePortOrUnixPath(in string) (string, uint16, error) {
|
||||
port, err := parsePort(in)
|
||||
if err == nil {
|
||||
return "", port, nil
|
||||
func parsePortRange(in string) ([]uint16, error) {
|
||||
parts := strings.Split(in, "-")
|
||||
if len(parts) != 2 {
|
||||
return nil, xerrors.Errorf("invalid port range specification %q", in)
|
||||
}
|
||||
|
||||
path, err := parseUnixPath(in)
|
||||
start, err := parsePort(parts[0])
|
||||
if err != nil {
|
||||
return "", 0, xerrors.Errorf("could not parse port or unix path %q: %w", in, err)
|
||||
return nil, xerrors.Errorf("parse range start port from %q: %w", in, err)
|
||||
}
|
||||
|
||||
return path, 0, nil
|
||||
}
|
||||
|
||||
func parseUnixUnix(in string) (string, uint16, string, error) {
|
||||
parts := strings.Split(in, ":")
|
||||
if len(parts) > 2 {
|
||||
return "", 0, "", xerrors.Errorf("invalid port-forward specification %q", in)
|
||||
}
|
||||
if len(parts) == 1 {
|
||||
// Duplicate the single part
|
||||
parts = append(parts, parts[0])
|
||||
}
|
||||
|
||||
localPath, localPort, err := parsePortOrUnixPath(parts[0])
|
||||
end, err := parsePort(parts[1])
|
||||
if err != nil {
|
||||
return "", 0, "", xerrors.Errorf("parse local part of spec %q: %w", in, err)
|
||||
return nil, xerrors.Errorf("parse range end port from %q: %w", in, err)
|
||||
}
|
||||
|
||||
// We don't really touch the remote path at all since it gets cleaned
|
||||
// up/expanded on the remote.
|
||||
return localPath, localPort, parts[1], nil
|
||||
if end < start {
|
||||
return nil, xerrors.Errorf("range end port %v is less than start port %v", end, start)
|
||||
}
|
||||
var ports []uint16
|
||||
for i := start; i <= end; i++ {
|
||||
ports = append(ports, i)
|
||||
}
|
||||
return ports, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_parsePortForwards(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
portForwardSpecToString := func(v []portForwardSpec) (out []string) {
|
||||
for _, p := range v {
|
||||
require.Equal(t, p.listenNetwork, p.dialNetwork)
|
||||
out = append(out, fmt.Sprintf("%s:%s", strings.Replace(p.listenAddress, "127.0.0.1:", "", 1), strings.Replace(p.dialAddress, "127.0.0.1:", "", 1)))
|
||||
}
|
||||
return out
|
||||
}
|
||||
type args struct {
|
||||
tcpSpecs []string
|
||||
udpSpecs []string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "TCP mixed ports and ranges",
|
||||
args: args{
|
||||
tcpSpecs: []string{
|
||||
"8000,8080:8081,9000-9002,9003-9004:9005-9006",
|
||||
"10000",
|
||||
},
|
||||
},
|
||||
want: []string{
|
||||
"8000:8000",
|
||||
"8080:8081",
|
||||
"9000:9000",
|
||||
"9001:9001",
|
||||
"9002:9002",
|
||||
"9003:9005",
|
||||
"9004:9006",
|
||||
"10000:10000",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "UDP with port range",
|
||||
args: args{
|
||||
udpSpecs: []string{"8000,8080-8081"},
|
||||
},
|
||||
want: []string{
|
||||
"8000:8000",
|
||||
"8080:8080",
|
||||
"8081:8081",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Bad port range",
|
||||
args: args{
|
||||
tcpSpecs: []string{"8000-7000"},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Bad dest port range",
|
||||
args: args{
|
||||
tcpSpecs: []string{"8080-8081:9080-9082"},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, err := parsePortForwards(tt.args.tcpSpecs, tt.args.udpSpecs)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("parsePortForwards() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
gotStrings := portForwardSpecToString(got)
|
||||
require.Equal(t, tt.want, gotStrings)
|
||||
})
|
||||
}
|
||||
}
|
||||
+43
-169
@@ -1,17 +1,12 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pion/udp"
|
||||
@@ -23,11 +18,13 @@ import (
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/provisioner/echo"
|
||||
"github.com/coder/coder/provisionersdk/proto"
|
||||
"github.com/coder/coder/pty/ptytest"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
func TestPortForward(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Skip("These tests flake... a lot. It seems related to the Tailscale change, but all other tests pass...")
|
||||
|
||||
t.Run("None", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
@@ -37,15 +34,17 @@ func TestPortForward(t *testing.T) {
|
||||
|
||||
cmd, root := clitest.New(t, "port-forward", "blah")
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf := newThreadSafeBuffer()
|
||||
cmd.SetOut(buf)
|
||||
pty := ptytest.New(t)
|
||||
cmd.SetIn(pty.Input())
|
||||
cmd.SetOut(pty.Output())
|
||||
cmd.SetErr(pty.Output())
|
||||
|
||||
err := cmd.Execute()
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "no port-forwards")
|
||||
|
||||
// Check that the help was printed.
|
||||
require.Contains(t, buf.String(), "port-forward <workspace>")
|
||||
pty.ExpectMatch("port-forward <workspace>")
|
||||
})
|
||||
|
||||
cases := []struct {
|
||||
@@ -58,7 +57,7 @@ func TestPortForward(t *testing.T) {
|
||||
// setupRemote creates a "remote" listener to emulate a service in the
|
||||
// workspace.
|
||||
setupRemote func(t *testing.T) net.Listener
|
||||
// setupLocal returns an available port or Unix socket path that the
|
||||
// setupLocal returns an available port that the
|
||||
// port-forward command will listen on "locally". Returns the address
|
||||
// you pass to net.Dial, and the port/path you pass to `coder
|
||||
// port-forward`.
|
||||
@@ -110,34 +109,14 @@ func TestPortForward(t *testing.T) {
|
||||
return l.Addr().String(), port
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Unix",
|
||||
network: "unix",
|
||||
flag: "--unix=%v:%v",
|
||||
setupRemote: func(t *testing.T) net.Listener {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Unix socket forwarding isn't supported on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
l, err := net.Listen("unix", filepath.Join(tmpDir, "test.sock"))
|
||||
require.NoError(t, err, "create UDP listener")
|
||||
return l
|
||||
},
|
||||
setupLocal: func(t *testing.T) (string, string) {
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "test.sock")
|
||||
return path, path
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Setup agent once to be shared between test-cases (avoid expensive
|
||||
// non-parallel setup).
|
||||
var (
|
||||
client = coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user = coderdtest.CreateFirstUser(t, client)
|
||||
_, workspace = runAgent(t, client, user.UserID)
|
||||
client = coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user = coderdtest.CreateFirstUser(t, client)
|
||||
workspace = runAgent(t, client, user.UserID)
|
||||
)
|
||||
|
||||
for _, c := range cases { //nolint:paralleltest // the `c := c` confuses the linter
|
||||
@@ -155,17 +134,19 @@ func TestPortForward(t *testing.T) {
|
||||
|
||||
// Launch port-forward in a goroutine so we can start dialing
|
||||
// the "local" listener.
|
||||
cmd, root := clitest.New(t, "port-forward", workspace.Name, flag)
|
||||
cmd, root := clitest.New(t, "-v", "port-forward", workspace.Name, flag)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf := newThreadSafeBuffer()
|
||||
cmd.SetOut(buf)
|
||||
pty := ptytest.New(t)
|
||||
cmd.SetIn(pty.Input())
|
||||
cmd.SetOut(pty.Output())
|
||||
cmd.SetErr(pty.Output())
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
errC := make(chan error)
|
||||
go func() {
|
||||
errC <- cmd.ExecuteContext(ctx)
|
||||
}()
|
||||
waitForPortForwardReady(t, buf)
|
||||
pty.ExpectMatch("Ready!")
|
||||
|
||||
t.Parallel() // Port is reserved, enable parallel execution.
|
||||
|
||||
@@ -201,17 +182,19 @@ func TestPortForward(t *testing.T) {
|
||||
|
||||
// Launch port-forward in a goroutine so we can start dialing
|
||||
// the "local" listeners.
|
||||
cmd, root := clitest.New(t, "port-forward", workspace.Name, flag1, flag2)
|
||||
cmd, root := clitest.New(t, "-v", "port-forward", workspace.Name, flag1, flag2)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf := newThreadSafeBuffer()
|
||||
cmd.SetOut(buf)
|
||||
pty := ptytest.New(t)
|
||||
cmd.SetIn(pty.Input())
|
||||
cmd.SetOut(pty.Output())
|
||||
cmd.SetErr(pty.Output())
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
errC := make(chan error)
|
||||
go func() {
|
||||
errC <- cmd.ExecuteContext(ctx)
|
||||
}()
|
||||
waitForPortForwardReady(t, buf)
|
||||
pty.ExpectMatch("Ready!")
|
||||
|
||||
t.Parallel() // Port is reserved, enable parallel execution.
|
||||
|
||||
@@ -234,74 +217,16 @@ func TestPortForward(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// Test doing a TCP -> Unix forward.
|
||||
//nolint:paralleltest
|
||||
t.Run("TCP2Unix", func(t *testing.T) {
|
||||
var (
|
||||
// Find the TCP and Unix cases so we can use their setupLocal and
|
||||
// setupRemote methods respectively.
|
||||
tcpCase = cases[0]
|
||||
unixCase = cases[2]
|
||||
|
||||
// Setup remote Unix listener.
|
||||
p1 = setupTestListener(t, unixCase.setupRemote(t))
|
||||
)
|
||||
|
||||
// Create a flag that forwards from local TCP to Unix listener 1.
|
||||
// Notably this is a --unix flag.
|
||||
localAddress, localFlag := tcpCase.setupLocal(t)
|
||||
flag := fmt.Sprintf(unixCase.flag, localFlag, p1)
|
||||
|
||||
// Launch port-forward in a goroutine so we can start dialing
|
||||
// the "local" listener.
|
||||
cmd, root := clitest.New(t, "port-forward", workspace.Name, flag)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf := newThreadSafeBuffer()
|
||||
cmd.SetOut(buf)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
errC := make(chan error)
|
||||
go func() {
|
||||
errC <- cmd.ExecuteContext(ctx)
|
||||
}()
|
||||
waitForPortForwardReady(t, buf)
|
||||
|
||||
t.Parallel() // Port is reserved, enable parallel execution.
|
||||
|
||||
// Open two connections simultaneously and test them out of
|
||||
// sync.
|
||||
d := net.Dialer{Timeout: testutil.WaitShort}
|
||||
c1, err := d.DialContext(ctx, tcpCase.network, localAddress)
|
||||
require.NoError(t, err, "open connection 1 to 'local' listener")
|
||||
defer c1.Close()
|
||||
c2, err := d.DialContext(ctx, tcpCase.network, localAddress)
|
||||
require.NoError(t, err, "open connection 2 to 'local' listener")
|
||||
defer c2.Close()
|
||||
testDial(t, c2)
|
||||
testDial(t, c1)
|
||||
|
||||
cancel()
|
||||
err = <-errC
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
})
|
||||
|
||||
// Test doing TCP, UDP and Unix at the same time.
|
||||
// Test doing TCP and UDP at the same time.
|
||||
//nolint:paralleltest
|
||||
t.Run("All", func(t *testing.T) {
|
||||
var (
|
||||
// These aren't fixed size because we exclude Unix on Windows.
|
||||
dials = []addr{}
|
||||
flags = []string{}
|
||||
)
|
||||
|
||||
// Start listeners and populate arrays with the cases.
|
||||
for _, c := range cases {
|
||||
if strings.HasPrefix(c.network, "unix") && runtime.GOOS == "windows" {
|
||||
// Unix isn't supported on Windows, but we can still
|
||||
// test other protocols together.
|
||||
continue
|
||||
}
|
||||
|
||||
p := setupTestListener(t, c.setupRemote(t))
|
||||
|
||||
localAddress, localFlag := c.setupLocal(t)
|
||||
@@ -314,17 +239,19 @@ func TestPortForward(t *testing.T) {
|
||||
|
||||
// Launch port-forward in a goroutine so we can start dialing
|
||||
// the "local" listeners.
|
||||
cmd, root := clitest.New(t, append([]string{"port-forward", workspace.Name}, flags...)...)
|
||||
cmd, root := clitest.New(t, append([]string{"-v", "port-forward", workspace.Name}, flags...)...)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf := newThreadSafeBuffer()
|
||||
cmd.SetOut(buf)
|
||||
pty := ptytest.New(t)
|
||||
cmd.SetIn(pty.Input())
|
||||
cmd.SetOut(pty.Output())
|
||||
cmd.SetErr(pty.Output())
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
errC := make(chan error)
|
||||
go func() {
|
||||
errC <- cmd.ExecuteContext(ctx)
|
||||
}()
|
||||
waitForPortForwardReady(t, buf)
|
||||
pty.ExpectMatch("Ready!")
|
||||
|
||||
t.Parallel() // Port is reserved, enable parallel execution.
|
||||
|
||||
@@ -355,7 +282,8 @@ func TestPortForward(t *testing.T) {
|
||||
|
||||
// runAgent creates a fake workspace and starts an agent locally for that
|
||||
// workspace. The agent will be cleaned up on test completion.
|
||||
func runAgent(t *testing.T, client *codersdk.Client, userID uuid.UUID) ([]codersdk.WorkspaceResource, codersdk.Workspace) {
|
||||
// nolint:unused
|
||||
func runAgent(t *testing.T, client *codersdk.Client, userID uuid.UUID) codersdk.Workspace {
|
||||
ctx := context.Background()
|
||||
user, err := client.User(ctx, userID.String())
|
||||
require.NoError(t, err, "specified user does not exist")
|
||||
@@ -391,8 +319,12 @@ func runAgent(t *testing.T, client *codersdk.Client, userID uuid.UUID) ([]coders
|
||||
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
|
||||
|
||||
// Start workspace agent in a goroutine
|
||||
cmd, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String(), "--wireguard=false")
|
||||
cmd, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String())
|
||||
clitest.SetupConfig(t, client, root)
|
||||
pty := ptytest.New(t)
|
||||
cmd.SetIn(pty.Input())
|
||||
cmd.SetOut(pty.Output())
|
||||
cmd.SetErr(pty.Output())
|
||||
errC := make(chan error)
|
||||
agentCtx, agentCancel := context.WithCancel(ctx)
|
||||
t.Cleanup(func() {
|
||||
@@ -404,15 +336,13 @@ func runAgent(t *testing.T, client *codersdk.Client, userID uuid.UUID) ([]coders
|
||||
errC <- cmd.ExecuteContext(agentCtx)
|
||||
}()
|
||||
|
||||
coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
|
||||
resources, err := client.WorkspaceResourcesByBuild(context.Background(), workspace.LatestBuild.ID)
|
||||
require.NoError(t, err)
|
||||
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
|
||||
|
||||
return resources, workspace
|
||||
return workspace
|
||||
}
|
||||
|
||||
// setupTestListener starts accepting connections and echoing a single packet.
|
||||
// Returns the listener and the listen port or Unix path.
|
||||
// Returns the listener and the listen port.
|
||||
func setupTestListener(t *testing.T, l net.Listener) string {
|
||||
t.Helper()
|
||||
|
||||
@@ -444,11 +374,9 @@ func setupTestListener(t *testing.T, l net.Listener) string {
|
||||
}()
|
||||
|
||||
addr := l.Addr().String()
|
||||
if !strings.HasPrefix(l.Addr().Network(), "unix") {
|
||||
_, port, err := net.SplitHostPort(addr)
|
||||
require.NoErrorf(t, err, "split non-Unix listen path %q", addr)
|
||||
addr = port
|
||||
}
|
||||
_, port, err := net.SplitHostPort(addr)
|
||||
require.NoErrorf(t, err, "split non-Unix listen path %q", addr)
|
||||
addr = port
|
||||
|
||||
return addr
|
||||
}
|
||||
@@ -486,61 +414,7 @@ func assertWritePayload(t *testing.T, w io.Writer, payload []byte) {
|
||||
assert.Equal(t, len(payload), n, "payload length does not match")
|
||||
}
|
||||
|
||||
func waitForPortForwardReady(t *testing.T, output *threadSafeBuffer) {
|
||||
t.Helper()
|
||||
for i := 0; i < 100; i++ {
|
||||
time.Sleep(testutil.IntervalMedium)
|
||||
|
||||
data := output.String()
|
||||
if strings.Contains(data, "Ready!") {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
t.Fatal("port-forward command did not become ready in time")
|
||||
}
|
||||
|
||||
type addr struct {
|
||||
network string
|
||||
addr string
|
||||
}
|
||||
|
||||
type threadSafeBuffer struct {
|
||||
b *bytes.Buffer
|
||||
mut *sync.RWMutex
|
||||
}
|
||||
|
||||
func newThreadSafeBuffer() *threadSafeBuffer {
|
||||
return &threadSafeBuffer{
|
||||
b: bytes.NewBuffer(nil),
|
||||
mut: new(sync.RWMutex),
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
_ io.Reader = &threadSafeBuffer{}
|
||||
_ io.Writer = &threadSafeBuffer{}
|
||||
)
|
||||
|
||||
// Read implements io.Reader.
|
||||
func (b *threadSafeBuffer) Read(p []byte) (int, error) {
|
||||
b.mut.RLock()
|
||||
defer b.mut.RUnlock()
|
||||
|
||||
return b.b.Read(p)
|
||||
}
|
||||
|
||||
// Write implements io.Writer.
|
||||
func (b *threadSafeBuffer) Write(p []byte) (int, error) {
|
||||
b.mut.Lock()
|
||||
defer b.mut.Unlock()
|
||||
|
||||
return b.b.Write(p)
|
||||
}
|
||||
|
||||
func (b *threadSafeBuffer) String() string {
|
||||
b.mut.RLock()
|
||||
defer b.mut.RUnlock()
|
||||
|
||||
return b.b.String()
|
||||
}
|
||||
|
||||
+1
-1
@@ -18,7 +18,7 @@ func publickey() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "publickey",
|
||||
Aliases: []string{"pubkey"},
|
||||
Short: "Output your public key for Git operations",
|
||||
Short: "Output your Coder public key used for Git operations",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
client, err := CreateClient(cmd)
|
||||
if err != nil {
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/coder/coder/cli/cliflag"
|
||||
"github.com/coder/coder/cli/cliui"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/migrations"
|
||||
"github.com/coder/coder/coderd/userpassword"
|
||||
)
|
||||
|
||||
@@ -20,7 +21,7 @@ func resetPassword() *cobra.Command {
|
||||
|
||||
root := &cobra.Command{
|
||||
Use: "reset-password <username>",
|
||||
Short: "Reset a user's password by directly updating the database",
|
||||
Short: "Directly connect to the database to reset a user's password",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
username := args[0]
|
||||
@@ -35,7 +36,7 @@ func resetPassword() *cobra.Command {
|
||||
return xerrors.Errorf("ping postgres: %w", err)
|
||||
}
|
||||
|
||||
err = database.EnsureClean(sqlDB)
|
||||
err = migrations.EnsureClean(sqlDB)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("database needs migration: %w", err)
|
||||
}
|
||||
|
||||
@@ -41,13 +41,14 @@ func TestResetPassword(t *testing.T) {
|
||||
serverCmd, cfg := clitest.New(t,
|
||||
"server",
|
||||
"--address", ":0",
|
||||
"--access-url", "example.com",
|
||||
"--postgres-url", connectionURL,
|
||||
"--cache-dir", t.TempDir(),
|
||||
)
|
||||
go func() {
|
||||
defer close(serverDone)
|
||||
err = serverCmd.ExecuteContext(ctx)
|
||||
assert.ErrorIs(t, err, context.Canceled)
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
var rawURL string
|
||||
require.Eventually(t, func() bool {
|
||||
|
||||
+105
-14
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
@@ -21,6 +23,7 @@ import (
|
||||
"github.com/coder/coder/cli/cliflag"
|
||||
"github.com/coder/coder/cli/cliui"
|
||||
"github.com/coder/coder/cli/config"
|
||||
"github.com/coder/coder/cli/deployment"
|
||||
"github.com/coder/coder/coderd"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
@@ -41,20 +44,24 @@ const (
|
||||
varAgentToken = "agent-token"
|
||||
varAgentURL = "agent-url"
|
||||
varGlobalConfig = "global-config"
|
||||
varHeader = "header"
|
||||
varNoOpen = "no-open"
|
||||
varNoVersionCheck = "no-version-warning"
|
||||
varNoFeatureWarning = "no-feature-warning"
|
||||
varForceTty = "force-tty"
|
||||
varVerbose = "verbose"
|
||||
varExperimental = "experimental"
|
||||
notLoggedInMessage = "You are not logged in. Try logging in using 'coder login <url>'."
|
||||
|
||||
envNoVersionCheck = "CODER_NO_VERSION_WARNING"
|
||||
envNoFeatureWarning = "CODER_NO_FEATURE_WARNING"
|
||||
envExperimental = "CODER_EXPERIMENTAL"
|
||||
envSessionToken = "CODER_SESSION_TOKEN"
|
||||
envURL = "CODER_URL"
|
||||
)
|
||||
|
||||
var (
|
||||
errUnauthenticated = xerrors.New(notLoggedInMessage)
|
||||
envSessionToken = "CODER_SESSION_TOKEN"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -89,12 +96,15 @@ func Core() []*cobra.Command {
|
||||
users(),
|
||||
versionCmd(),
|
||||
workspaceAgent(),
|
||||
features(),
|
||||
tokens(),
|
||||
}
|
||||
}
|
||||
|
||||
func AGPL() []*cobra.Command {
|
||||
all := append(Core(), Server(coderd.New))
|
||||
all := append(Core(), Server(deployment.Flags(), func(_ context.Context, o *coderd.Options) (*coderd.API, io.Closer, error) {
|
||||
api := coderd.New(o)
|
||||
return api, api, nil
|
||||
}))
|
||||
return all
|
||||
}
|
||||
|
||||
@@ -103,7 +113,7 @@ func Root(subcommands []*cobra.Command) *cobra.Command {
|
||||
Use: "coder",
|
||||
SilenceErrors: true,
|
||||
SilenceUsage: true,
|
||||
Long: `Coder — A tool for provisioning self-hosted development environments.
|
||||
Long: `Coder — A tool for provisioning self-hosted development environments with Terraform.
|
||||
`,
|
||||
PersistentPreRun: func(cmd *cobra.Command, args []string) {
|
||||
if cliflag.IsSetBool(cmd, varNoVersionCheck) &&
|
||||
@@ -162,27 +172,59 @@ func Root(subcommands []*cobra.Command) *cobra.Command {
|
||||
}
|
||||
|
||||
cmd.AddCommand(subcommands...)
|
||||
fixUnknownSubcommandError(cmd.Commands())
|
||||
|
||||
cmd.SetUsageTemplate(usageTemplate())
|
||||
|
||||
cmd.PersistentFlags().String(varURL, "", "Specify the URL to your deployment.")
|
||||
cliflag.String(cmd.PersistentFlags(), varURL, "", envURL, "", "URL to a deployment.")
|
||||
cliflag.Bool(cmd.PersistentFlags(), varNoVersionCheck, "", envNoVersionCheck, false, "Suppress warning when client and server versions do not match.")
|
||||
cliflag.Bool(cmd.PersistentFlags(), varNoFeatureWarning, "", envNoFeatureWarning, false, "Suppress warnings about unlicensed features.")
|
||||
cliflag.String(cmd.PersistentFlags(), varToken, "", envSessionToken, "", fmt.Sprintf("Specify an authentication token. For security reasons setting %s is preferred.", envSessionToken))
|
||||
cliflag.String(cmd.PersistentFlags(), varAgentToken, "", "CODER_AGENT_TOKEN", "", "Specify an agent authentication token.")
|
||||
cliflag.String(cmd.PersistentFlags(), varAgentToken, "", "CODER_AGENT_TOKEN", "", "An agent authentication token.")
|
||||
_ = cmd.PersistentFlags().MarkHidden(varAgentToken)
|
||||
cliflag.String(cmd.PersistentFlags(), varAgentURL, "", "CODER_AGENT_URL", "", "Specify the URL for an agent to access your deployment.")
|
||||
cliflag.String(cmd.PersistentFlags(), varAgentURL, "", "CODER_AGENT_URL", "", "URL for an agent to access your deployment.")
|
||||
_ = cmd.PersistentFlags().MarkHidden(varAgentURL)
|
||||
cliflag.String(cmd.PersistentFlags(), varGlobalConfig, "", "CODER_CONFIG_DIR", configdir.LocalConfig("coderv2"), "Specify the path to the global `coder` config directory.")
|
||||
cliflag.String(cmd.PersistentFlags(), varGlobalConfig, "", "CODER_CONFIG_DIR", configdir.LocalConfig("coderv2"), "Path to the global `coder` config directory.")
|
||||
cliflag.StringArray(cmd.PersistentFlags(), varHeader, "", "CODER_HEADER", []string{}, "HTTP headers added to all requests. Provide as \"Key=Value\"")
|
||||
cmd.PersistentFlags().Bool(varForceTty, false, "Force the `coder` command to run as if connected to a TTY.")
|
||||
_ = cmd.PersistentFlags().MarkHidden(varForceTty)
|
||||
cmd.PersistentFlags().Bool(varNoOpen, false, "Block automatically opening URLs in the browser.")
|
||||
_ = cmd.PersistentFlags().MarkHidden(varNoOpen)
|
||||
cliflag.Bool(cmd.PersistentFlags(), varVerbose, "v", "CODER_VERBOSE", false, "Enable verbose output")
|
||||
cliflag.Bool(cmd.PersistentFlags(), varVerbose, "v", "CODER_VERBOSE", false, "Enable verbose output.")
|
||||
cliflag.Bool(cmd.PersistentFlags(), varExperimental, "", envExperimental, false, "Enable experimental features. Experimental features are not ready for production.")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
// fixUnknownSubcommandError modifies the provided commands so that the
|
||||
// ones with subcommands output the correct error message when an
|
||||
// unknown subcommand is invoked.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// unknown command "bad" for "coder templates"
|
||||
func fixUnknownSubcommandError(commands []*cobra.Command) {
|
||||
for _, sc := range commands {
|
||||
if sc.HasSubCommands() {
|
||||
if sc.Run == nil && sc.RunE == nil {
|
||||
if sc.Args != nil {
|
||||
// In case the developer does not know about this
|
||||
// behavior in Cobra they must verify correct
|
||||
// behavior. For instance, settings Args to
|
||||
// `cobra.ExactArgs(0)` will not give the same
|
||||
// message as `cobra.NoArgs`. Likewise, omitting the
|
||||
// run function will not give the wanted error.
|
||||
panic("developer error: subcommand has subcommands and Args but no Run or RunE")
|
||||
}
|
||||
sc.Args = cobra.NoArgs
|
||||
sc.Run = func(*cobra.Command, []string) {}
|
||||
}
|
||||
|
||||
fixUnknownSubcommandError(sc.Commands())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// versionCmd prints the coder version
|
||||
func versionCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
@@ -237,8 +279,32 @@ func CreateClient(cmd *cobra.Command) (*codersdk.Client, error) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
client, err := createUnauthenticatedClient(cmd, serverURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client.SessionToken = token
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func createUnauthenticatedClient(cmd *cobra.Command, serverURL *url.URL) (*codersdk.Client, error) {
|
||||
client := codersdk.New(serverURL)
|
||||
client.SessionToken = strings.TrimSpace(token)
|
||||
headers, err := cmd.Flags().GetStringArray(varHeader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
transport := &headerTransport{
|
||||
transport: http.DefaultTransport,
|
||||
headers: map[string]string{},
|
||||
}
|
||||
for _, header := range headers {
|
||||
parts := strings.SplitN(header, "=", 2)
|
||||
if len(parts) < 2 {
|
||||
return nil, xerrors.Errorf("split header %q had less than two parts", header)
|
||||
}
|
||||
transport.headers[parts[0]] = parts[1]
|
||||
}
|
||||
client.HTTPClient.Transport = transport
|
||||
return client, nil
|
||||
}
|
||||
|
||||
@@ -521,11 +587,36 @@ func checkWarnings(cmd *cobra.Command, client *codersdk.Client) error {
|
||||
defer cancel()
|
||||
|
||||
entitlements, err := client.Entitlements(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get entitlements to show warnings: %w", err)
|
||||
if err == nil {
|
||||
for _, w := range entitlements.Warnings {
|
||||
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), cliui.Styles.Warn.Render(w))
|
||||
}
|
||||
}
|
||||
for _, w := range entitlements.Warnings {
|
||||
_, _ = fmt.Fprintln(cmd.ErrOrStderr(), cliui.Styles.Warn.Render(w))
|
||||
return nil
|
||||
}
|
||||
|
||||
type headerTransport struct {
|
||||
transport http.RoundTripper
|
||||
headers map[string]string
|
||||
}
|
||||
|
||||
func (h *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
for k, v := range h.headers {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
return h.transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
// ExperimentalEnabled returns if the experimental feature flag is enabled.
|
||||
func ExperimentalEnabled(cmd *cobra.Command) bool {
|
||||
return cliflag.IsSetBool(cmd, varExperimental)
|
||||
}
|
||||
|
||||
// EnsureExperimental will ensure that the experimental feature flag is set if the given flag is set.
|
||||
func EnsureExperimental(cmd *cobra.Command, name string) error {
|
||||
_, set := cliflag.IsSet(cmd, name)
|
||||
if set && !ExperimentalEnabled(cmd) {
|
||||
return xerrors.Errorf("flag %s is set but requires flag --experimental or environment variable CODER_EXPERIMENTAL=true.", name)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -2,14 +2,18 @@ package cli_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/buildinfo"
|
||||
"github.com/coder/coder/cli"
|
||||
"github.com/coder/coder/cli/cliflag"
|
||||
"github.com/coder/coder/cli/clitest"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
@@ -129,4 +133,40 @@ func TestRoot(t *testing.T) {
|
||||
require.Contains(t, output, buildinfo.Version(), "has version")
|
||||
require.Contains(t, output, buildinfo.ExternalURL(), "has url")
|
||||
})
|
||||
|
||||
t.Run("Header", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
done := make(chan struct{})
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "wow", r.Header.Get("X-Testing"))
|
||||
w.WriteHeader(http.StatusGone)
|
||||
select {
|
||||
case <-done:
|
||||
close(done)
|
||||
default:
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
buf := new(bytes.Buffer)
|
||||
cmd, _ := clitest.New(t, "--header", "X-Testing=wow", "login", srv.URL)
|
||||
cmd.SetOut(buf)
|
||||
// This won't succeed, because we're using the login cmd to assert requests.
|
||||
_ = cmd.Execute()
|
||||
})
|
||||
|
||||
t.Run("Experimental", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cmd, _ := clitest.New(t, "--experimental")
|
||||
err := cmd.Execute()
|
||||
require.NoError(t, err)
|
||||
require.True(t, cli.ExperimentalEnabled(cmd))
|
||||
|
||||
cmd, _ = clitest.New(t, "help", "--verbose")
|
||||
_ = cmd.Execute()
|
||||
_, set := cliflag.IsSet(cmd, "verbose")
|
||||
require.True(t, set)
|
||||
require.ErrorContains(t, cli.EnsureExperimental(cmd, "verbose"), "--experimental")
|
||||
})
|
||||
}
|
||||
|
||||
+4
-1
@@ -57,7 +57,10 @@ func schedules() *cobra.Command {
|
||||
scheduleCmd := &cobra.Command{
|
||||
Annotations: workspaceCommand,
|
||||
Use: "schedule { show | start | stop | override } <workspace>",
|
||||
Short: "Modify scheduled stop and start times for your workspace",
|
||||
Short: "Schedule automated start and stop times for workspaces",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return cmd.Help()
|
||||
},
|
||||
}
|
||||
|
||||
scheduleCmd.AddCommand(
|
||||
|
||||
@@ -61,7 +61,6 @@ func TestScheduleShow(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctx = context.Background()
|
||||
client = coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
user = coderdtest.CreateFirstUser(t, client)
|
||||
version = coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
|
||||
@@ -69,14 +68,13 @@ func TestScheduleShow(t *testing.T) {
|
||||
project = coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
workspace = coderdtest.CreateWorkspace(t, client, user.OrganizationID, project.ID, func(cwr *codersdk.CreateWorkspaceRequest) {
|
||||
cwr.AutostartSchedule = nil
|
||||
cwr.TTLMillis = nil
|
||||
})
|
||||
_ = coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
|
||||
cmdArgs = []string{"schedule", "show", workspace.Name}
|
||||
stdoutBuf = &bytes.Buffer{}
|
||||
)
|
||||
|
||||
// unset workspace TTL
|
||||
require.NoError(t, client.UpdateWorkspaceTTL(ctx, workspace.ID, codersdk.UpdateWorkspaceTTLRequest{TTLMillis: nil}))
|
||||
|
||||
cmd, root := clitest.New(t, cmdArgs...)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
cmd.SetOut(stdoutBuf)
|
||||
|
||||
+220
-330
@@ -5,7 +5,6 @@ import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"database/sql"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -18,6 +17,7 @@ import (
|
||||
"os/signal"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -28,13 +28,11 @@ import (
|
||||
embeddedpostgres "github.com/fergusstrange/embedded-postgres"
|
||||
"github.com/google/go-github/v43/github"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pion/turn/v2"
|
||||
"github.com/pion/webrtc/v3"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cobra"
|
||||
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/oauth2"
|
||||
xgithub "golang.org/x/oauth2/github"
|
||||
"golang.org/x/sync/errgroup"
|
||||
@@ -46,19 +44,20 @@ import (
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/sloghuman"
|
||||
"github.com/coder/coder/buildinfo"
|
||||
"github.com/coder/coder/cli/cliflag"
|
||||
"github.com/coder/coder/cli/cliui"
|
||||
"github.com/coder/coder/cli/config"
|
||||
"github.com/coder/coder/cli/deployment"
|
||||
"github.com/coder/coder/coderd"
|
||||
"github.com/coder/coder/coderd/autobuild/executor"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/databasefake"
|
||||
"github.com/coder/coder/coderd/database/migrations"
|
||||
"github.com/coder/coder/coderd/devtunnel"
|
||||
"github.com/coder/coder/coderd/gitsshkey"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/prometheusmetrics"
|
||||
"github.com/coder/coder/coderd/telemetry"
|
||||
"github.com/coder/coder/coderd/tracing"
|
||||
"github.com/coder/coder/coderd/turnconn"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/cryptorand"
|
||||
"github.com/coder/coder/provisioner/echo"
|
||||
@@ -70,67 +69,14 @@ import (
|
||||
)
|
||||
|
||||
// nolint:gocyclo
|
||||
func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
|
||||
var (
|
||||
accessURL string
|
||||
address string
|
||||
autobuildPollInterval time.Duration
|
||||
derpServerEnabled bool
|
||||
derpServerRegionID int
|
||||
derpServerRegionCode string
|
||||
derpServerRegionName string
|
||||
derpServerSTUNAddrs []string
|
||||
derpConfigURL string
|
||||
promEnabled bool
|
||||
promAddress string
|
||||
pprofEnabled bool
|
||||
pprofAddress string
|
||||
cacheDir string
|
||||
inMemoryDatabase bool
|
||||
// provisionerDaemonCount is a uint8 to ensure a number > 0.
|
||||
provisionerDaemonCount uint8
|
||||
postgresURL string
|
||||
oauth2GithubClientID string
|
||||
oauth2GithubClientSecret string
|
||||
oauth2GithubAllowedOrganizations []string
|
||||
oauth2GithubAllowedTeams []string
|
||||
oauth2GithubAllowSignups bool
|
||||
oauth2GithubEnterpriseBaseURL string
|
||||
oidcAllowSignups bool
|
||||
oidcClientID string
|
||||
oidcClientSecret string
|
||||
oidcEmailDomain string
|
||||
oidcIssuerURL string
|
||||
oidcScopes []string
|
||||
tailscaleEnable bool
|
||||
telemetryEnable bool
|
||||
telemetryURL string
|
||||
tlsCertFile string
|
||||
tlsClientCAFile string
|
||||
tlsClientAuth string
|
||||
tlsEnable bool
|
||||
tlsKeyFile string
|
||||
tlsMinVersion string
|
||||
turnRelayAddress string
|
||||
tunnel bool
|
||||
stunServers []string
|
||||
trace bool
|
||||
secureAuthCookie bool
|
||||
sshKeygenAlgorithmRaw string
|
||||
autoImportTemplates []string
|
||||
spooky bool
|
||||
verbose bool
|
||||
metricsCacheRefreshInterval time.Duration
|
||||
agentStatRefreshInterval time.Duration
|
||||
)
|
||||
|
||||
func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *coderd.Options) (*coderd.API, io.Closer, error)) *cobra.Command {
|
||||
root := &cobra.Command{
|
||||
Use: "server",
|
||||
Short: "Start a Coder server",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
printLogo(cmd, spooky)
|
||||
printLogo(cmd)
|
||||
logger := slog.Make(sloghuman.Sink(cmd.ErrOrStderr()))
|
||||
if verbose {
|
||||
if dflags.Verbose.Value {
|
||||
logger = logger.Leveled(slog.LevelDebug)
|
||||
}
|
||||
|
||||
@@ -158,36 +104,51 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
|
||||
defer http.DefaultClient.CloseIdleConnections()
|
||||
|
||||
var (
|
||||
tracerProvider *sdktrace.TracerProvider
|
||||
tracerProvider trace.TracerProvider
|
||||
err error
|
||||
sqlDriver = "postgres"
|
||||
)
|
||||
if trace {
|
||||
tracerProvider, err = tracing.TracerProvider(ctx, "coderd")
|
||||
|
||||
// Coder tracing should be disabled if telemetry is disabled unless
|
||||
// --telemetry-trace was explicitly provided.
|
||||
shouldCoderTrace := dflags.TelemetryEnable.Value && !isTest()
|
||||
// Only override if telemetryTraceEnable was specifically set.
|
||||
// By default we want it to be controlled by telemetryEnable.
|
||||
if cmd.Flags().Changed("telemetry-trace") {
|
||||
shouldCoderTrace = dflags.TelemetryTraceEnable.Value
|
||||
}
|
||||
|
||||
if dflags.TraceEnable.Value || shouldCoderTrace {
|
||||
sdkTracerProvider, closeTracing, err := tracing.TracerProvider(ctx, "coderd", tracing.TracerOpts{
|
||||
Default: dflags.TraceEnable.Value,
|
||||
Coder: shouldCoderTrace,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "failed to start telemetry exporter", slog.Error(err))
|
||||
logger.Warn(ctx, "start telemetry exporter", slog.Error(err))
|
||||
} else {
|
||||
// allow time for traces to flush even if command context is canceled
|
||||
defer func() {
|
||||
_ = shutdownWithTimeout(tracerProvider, 5*time.Second)
|
||||
_ = shutdownWithTimeout(closeTracing, 5*time.Second)
|
||||
}()
|
||||
|
||||
d, err := tracing.PostgresDriver(tracerProvider, "coderd.database")
|
||||
d, err := tracing.PostgresDriver(sdkTracerProvider, "coderd.database")
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "failed to start postgres tracing driver", slog.Error(err))
|
||||
logger.Warn(ctx, "start postgres tracing driver", slog.Error(err))
|
||||
} else {
|
||||
sqlDriver = d
|
||||
}
|
||||
|
||||
tracerProvider = sdkTracerProvider
|
||||
}
|
||||
}
|
||||
|
||||
config := createConfig(cmd)
|
||||
builtinPostgres := false
|
||||
// Only use built-in if PostgreSQL URL isn't specified!
|
||||
if !inMemoryDatabase && postgresURL == "" {
|
||||
if !dflags.InMemoryDatabase.Value && dflags.PostgresURL.Value == "" {
|
||||
var closeFunc func() error
|
||||
cmd.Printf("Using built-in PostgreSQL (%s)\n", config.PostgresPath())
|
||||
postgresURL, closeFunc, err = startBuiltinPostgres(ctx, config, logger)
|
||||
dflags.PostgresURL.Value, closeFunc, err = startBuiltinPostgres(ctx, config, logger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -200,17 +161,25 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
|
||||
}()
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", address)
|
||||
listener, err := net.Listen("tcp", dflags.Address.Value)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("listen %q: %w", address, err)
|
||||
return xerrors.Errorf("listen %q: %w", dflags.Address.Value, err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
if tlsEnable {
|
||||
listener, err = configureTLS(listener, tlsMinVersion, tlsClientAuth, tlsCertFile, tlsKeyFile, tlsClientCAFile)
|
||||
var tlsConfig *tls.Config
|
||||
if dflags.TLSEnable.Value {
|
||||
tlsConfig, err = configureTLS(
|
||||
dflags.TLSMinVersion.Value,
|
||||
dflags.TLSClientAuth.Value,
|
||||
dflags.TLSCertFiles.Value,
|
||||
dflags.TLSKeyFiles.Value,
|
||||
dflags.TLSClientCAFile.Value,
|
||||
)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("configure tls: %w", err)
|
||||
}
|
||||
listener = tls.NewListener(listener, tlsConfig)
|
||||
}
|
||||
|
||||
tcpAddr, valid := listener.Addr().(*net.TCPAddr)
|
||||
@@ -227,32 +196,29 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
|
||||
Scheme: "http",
|
||||
Host: tcpAddr.String(),
|
||||
}
|
||||
if tlsEnable {
|
||||
if dflags.TLSEnable.Value {
|
||||
localURL.Scheme = "https"
|
||||
}
|
||||
if accessURL == "" {
|
||||
accessURL = localURL.String()
|
||||
}
|
||||
|
||||
var (
|
||||
ctxTunnel, closeTunnel = context.WithCancel(ctx)
|
||||
devTunnel *devtunnel.Tunnel
|
||||
devTunnelErr <-chan error
|
||||
tunnel *devtunnel.Tunnel
|
||||
tunnelErr <-chan error
|
||||
)
|
||||
defer closeTunnel()
|
||||
|
||||
// If we're attempting to tunnel in dev-mode, the access URL
|
||||
// needs to be changed to use the tunnel.
|
||||
if tunnel {
|
||||
cmd.Printf("Opening tunnel so workspaces can connect to your deployment\n")
|
||||
devTunnel, devTunnelErr, err = devtunnel.New(ctxTunnel, logger.Named("devtunnel"))
|
||||
// If the access URL is empty, we attempt to run a reverse-proxy tunnel
|
||||
// to make the initial setup really simple.
|
||||
if dflags.AccessURL.Value == "" {
|
||||
cmd.Printf("Opening tunnel so workspaces can connect to your deployment. For production scenarios, specify an external access URL\n")
|
||||
tunnel, tunnelErr, err = devtunnel.New(ctxTunnel, logger.Named("devtunnel"))
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create tunnel: %w", err)
|
||||
}
|
||||
accessURL = devTunnel.URL
|
||||
dflags.AccessURL.Value = tunnel.URL
|
||||
}
|
||||
|
||||
accessURLParsed, err := parseURL(ctx, accessURL)
|
||||
accessURLParsed, err := parseURL(ctx, dflags.AccessURL.Value)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("parse URL: %w", err)
|
||||
}
|
||||
@@ -275,11 +241,11 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
|
||||
if isLocal {
|
||||
reason = "isn't externally reachable"
|
||||
}
|
||||
cmd.Printf("%s The access URL %s %s, this may cause unexpected problems when creating workspaces. Generate a unique *.try.coder.app URL with:\n", cliui.Styles.Warn.Render("Warning:"), cliui.Styles.Field.Render(accessURLParsed.String()), reason)
|
||||
cmd.Println(cliui.Styles.Code.Render(strings.Join(os.Args, " ") + " --tunnel"))
|
||||
cmd.Printf("%s The access URL %s %s, this may cause unexpected problems when creating workspaces. Generate a unique *.try.coder.app URL by not specifying an access URL.\n", cliui.Styles.Warn.Render("Warning:"), cliui.Styles.Field.Render(accessURLParsed.String()), reason)
|
||||
}
|
||||
|
||||
cmd.Printf("View the Web UI: %s\n", accessURLParsed.String())
|
||||
// A newline is added before for visibility in terminal output.
|
||||
cmd.Printf("\nView the Web UI: %s\n", accessURLParsed.String())
|
||||
|
||||
// Used for zero-trust instance identity with Google Cloud.
|
||||
googleTokenValidator, err := idtoken.NewValidator(ctx, option.WithoutAuthentication())
|
||||
@@ -287,33 +253,17 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
|
||||
return err
|
||||
}
|
||||
|
||||
sshKeygenAlgorithm, err := gitsshkey.ParseAlgorithm(sshKeygenAlgorithmRaw)
|
||||
sshKeygenAlgorithm, err := gitsshkey.ParseAlgorithm(dflags.SSHKeygenAlgorithm.Value)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("parse ssh keygen algorithm %s: %w", sshKeygenAlgorithmRaw, err)
|
||||
}
|
||||
|
||||
turnServer, err := turnconn.New(&turn.RelayAddressGeneratorStatic{
|
||||
RelayAddress: net.ParseIP(turnRelayAddress),
|
||||
Address: turnRelayAddress,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create turn server: %w", err)
|
||||
}
|
||||
defer turnServer.Close()
|
||||
|
||||
iceServers := make([]webrtc.ICEServer, 0)
|
||||
for _, stunServer := range stunServers {
|
||||
iceServers = append(iceServers, webrtc.ICEServer{
|
||||
URLs: []string{stunServer},
|
||||
})
|
||||
return xerrors.Errorf("parse ssh keygen algorithm %s: %w", dflags.SSHKeygenAlgorithm.Value, err)
|
||||
}
|
||||
|
||||
// Validate provided auto-import templates.
|
||||
var (
|
||||
validatedAutoImportTemplates = make([]coderd.AutoImportTemplate, len(autoImportTemplates))
|
||||
seenValidatedAutoImportTemplates = make(map[coderd.AutoImportTemplate]struct{}, len(autoImportTemplates))
|
||||
validatedAutoImportTemplates = make([]coderd.AutoImportTemplate, len(dflags.AutoImportTemplates.Value))
|
||||
seenValidatedAutoImportTemplates = make(map[coderd.AutoImportTemplate]struct{}, len(dflags.AutoImportTemplates.Value))
|
||||
)
|
||||
for i, autoImportTemplate := range autoImportTemplates {
|
||||
for i, autoImportTemplate := range dflags.AutoImportTemplates.Value {
|
||||
var v coderd.AutoImportTemplate
|
||||
switch autoImportTemplate {
|
||||
case "kubernetes":
|
||||
@@ -330,62 +280,83 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
|
||||
}
|
||||
|
||||
defaultRegion := &tailcfg.DERPRegion{
|
||||
RegionID: derpServerRegionID,
|
||||
RegionCode: derpServerRegionCode,
|
||||
RegionName: derpServerRegionName,
|
||||
EmbeddedRelay: true,
|
||||
RegionID: dflags.DerpServerRegionID.Value,
|
||||
RegionCode: dflags.DerpServerRegionCode.Value,
|
||||
RegionName: dflags.DerpServerRegionName.Value,
|
||||
Nodes: []*tailcfg.DERPNode{{
|
||||
Name: fmt.Sprintf("%db", derpServerRegionID),
|
||||
RegionID: derpServerRegionID,
|
||||
Name: fmt.Sprintf("%db", dflags.DerpServerRegionID.Value),
|
||||
RegionID: dflags.DerpServerRegionID.Value,
|
||||
HostName: accessURLParsed.Hostname(),
|
||||
DERPPort: accessURLPort,
|
||||
STUNPort: -1,
|
||||
ForceHTTP: accessURLParsed.Scheme == "http",
|
||||
}},
|
||||
}
|
||||
if !derpServerEnabled {
|
||||
if !dflags.DerpServerEnable.Value {
|
||||
defaultRegion = nil
|
||||
}
|
||||
derpMap, err := tailnet.NewDERPMap(ctx, defaultRegion, derpServerSTUNAddrs, derpConfigURL)
|
||||
derpMap, err := tailnet.NewDERPMap(ctx, defaultRegion, dflags.DerpServerSTUNAddresses.Value, dflags.DerpConfigURL.Value, dflags.DerpConfigPath.Value)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create derp map: %w", err)
|
||||
}
|
||||
|
||||
appHostname := strings.TrimSpace(dflags.WildcardAccessURL.Value)
|
||||
var appHostnameRegex *regexp.Regexp
|
||||
if appHostname != "" {
|
||||
appHostnameRegex, err = httpapi.CompileHostnamePattern(appHostname)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("parse wildcard access URL %q: %w", appHostname, err)
|
||||
}
|
||||
}
|
||||
|
||||
options := &coderd.Options{
|
||||
AccessURL: accessURLParsed,
|
||||
ICEServers: iceServers,
|
||||
AppHostname: appHostname,
|
||||
AppHostnameRegex: appHostnameRegex,
|
||||
Logger: logger.Named("coderd"),
|
||||
Database: databasefake.New(),
|
||||
DERPMap: derpMap,
|
||||
Pubsub: database.NewPubsubInMemory(),
|
||||
CacheDir: cacheDir,
|
||||
CacheDir: dflags.CacheDir.Value,
|
||||
GoogleTokenValidator: googleTokenValidator,
|
||||
SecureAuthCookie: secureAuthCookie,
|
||||
SecureAuthCookie: dflags.SecureAuthCookie.Value,
|
||||
SSHKeygenAlgorithm: sshKeygenAlgorithm,
|
||||
TailscaleEnable: tailscaleEnable,
|
||||
TURNServer: turnServer,
|
||||
TracerProvider: tracerProvider,
|
||||
Telemetry: telemetry.NewNoop(),
|
||||
AutoImportTemplates: validatedAutoImportTemplates,
|
||||
MetricsCacheRefreshInterval: metricsCacheRefreshInterval,
|
||||
AgentStatsRefreshInterval: agentStatRefreshInterval,
|
||||
MetricsCacheRefreshInterval: dflags.MetricsCacheRefreshInterval.Value,
|
||||
AgentStatsRefreshInterval: dflags.AgentStatRefreshInterval.Value,
|
||||
Experimental: ExperimentalEnabled(cmd),
|
||||
DeploymentFlags: dflags,
|
||||
}
|
||||
if tlsConfig != nil {
|
||||
options.TLSCertificates = tlsConfig.Certificates
|
||||
}
|
||||
|
||||
if oauth2GithubClientSecret != "" {
|
||||
options.GithubOAuth2Config, err = configureGithubOAuth2(accessURLParsed, oauth2GithubClientID, oauth2GithubClientSecret, oauth2GithubAllowSignups, oauth2GithubAllowedOrganizations, oauth2GithubAllowedTeams, oauth2GithubEnterpriseBaseURL)
|
||||
if dflags.OAuth2GithubClientSecret.Value != "" {
|
||||
options.GithubOAuth2Config, err = configureGithubOAuth2(accessURLParsed,
|
||||
dflags.OAuth2GithubClientID.Value,
|
||||
dflags.OAuth2GithubClientSecret.Value,
|
||||
dflags.OAuth2GithubAllowSignups.Value,
|
||||
dflags.OAuth2GithubAllowedOrganizations.Value,
|
||||
dflags.OAuth2GithubAllowedTeams.Value,
|
||||
dflags.OAuth2GithubEnterpriseBaseURL.Value,
|
||||
)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("configure github oauth2: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if oidcClientSecret != "" {
|
||||
if oidcClientID == "" {
|
||||
if dflags.OIDCClientSecret.Value != "" {
|
||||
if dflags.OIDCClientID.Value == "" {
|
||||
return xerrors.Errorf("OIDC client ID be set!")
|
||||
}
|
||||
if oidcIssuerURL == "" {
|
||||
if dflags.OIDCIssuerURL.Value == "" {
|
||||
return xerrors.Errorf("OIDC issuer URL must be set!")
|
||||
}
|
||||
|
||||
oidcProvider, err := oidc.NewProvider(ctx, oidcIssuerURL)
|
||||
oidcProvider, err := oidc.NewProvider(ctx, dflags.OIDCIssuerURL.Value)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("configure oidc provider: %w", err)
|
||||
}
|
||||
@@ -395,25 +366,25 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
|
||||
}
|
||||
options.OIDCConfig = &coderd.OIDCConfig{
|
||||
OAuth2Config: &oauth2.Config{
|
||||
ClientID: oidcClientID,
|
||||
ClientSecret: oidcClientSecret,
|
||||
ClientID: dflags.OIDCClientID.Value,
|
||||
ClientSecret: dflags.OIDCClientSecret.Value,
|
||||
RedirectURL: redirectURL.String(),
|
||||
Endpoint: oidcProvider.Endpoint(),
|
||||
Scopes: oidcScopes,
|
||||
Scopes: dflags.OIDCScopes.Value,
|
||||
},
|
||||
Verifier: oidcProvider.Verifier(&oidc.Config{
|
||||
ClientID: oidcClientID,
|
||||
ClientID: dflags.OIDCClientID.Value,
|
||||
}),
|
||||
EmailDomain: oidcEmailDomain,
|
||||
AllowSignups: oidcAllowSignups,
|
||||
EmailDomain: dflags.OIDCEmailDomain.Value,
|
||||
AllowSignups: dflags.OIDCAllowSignups.Value,
|
||||
}
|
||||
}
|
||||
|
||||
if inMemoryDatabase {
|
||||
if dflags.InMemoryDatabase.Value {
|
||||
options.Database = databasefake.New()
|
||||
options.Pubsub = database.NewPubsubInMemory()
|
||||
} else {
|
||||
sqlDB, err := sql.Open(sqlDriver, postgresURL)
|
||||
sqlDB, err := sql.Open(sqlDriver, dflags.PostgresURL.Value)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("dial postgres: %w", err)
|
||||
}
|
||||
@@ -423,12 +394,12 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
|
||||
if err != nil {
|
||||
return xerrors.Errorf("ping postgres: %w", err)
|
||||
}
|
||||
err = database.MigrateUp(sqlDB)
|
||||
err = migrations.Up(sqlDB)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("migrate up: %w", err)
|
||||
}
|
||||
options.Database = database.New(sqlDB)
|
||||
options.Pubsub, err = database.NewPubsub(ctx, sqlDB, postgresURL)
|
||||
options.Pubsub, err = database.NewPubsub(ctx, sqlDB, dflags.PostgresURL.Value)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create pubsub: %w", err)
|
||||
}
|
||||
@@ -451,27 +422,27 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
|
||||
}
|
||||
|
||||
// Parse the raw telemetry URL!
|
||||
telemetryURL, err := parseURL(ctx, telemetryURL)
|
||||
telemetryURL, err := parseURL(ctx, dflags.TelemetryURL.Value)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("parse telemetry url: %w", err)
|
||||
}
|
||||
// Disable telemetry if the in-memory database is used unless explicitly defined!
|
||||
if inMemoryDatabase && !cmd.Flags().Changed("telemetry") {
|
||||
telemetryEnable = false
|
||||
if dflags.InMemoryDatabase.Value && !cmd.Flags().Changed(dflags.TelemetryEnable.Flag) {
|
||||
dflags.TelemetryEnable.Value = false
|
||||
}
|
||||
if telemetryEnable {
|
||||
if dflags.TelemetryEnable.Value {
|
||||
options.Telemetry, err = telemetry.New(telemetry.Options{
|
||||
BuiltinPostgres: builtinPostgres,
|
||||
DeploymentID: deploymentID,
|
||||
Database: options.Database,
|
||||
Logger: logger.Named("telemetry"),
|
||||
URL: telemetryURL,
|
||||
GitHubOAuth: oauth2GithubClientID != "",
|
||||
OIDCAuth: oidcClientID != "",
|
||||
OIDCIssuerURL: oidcIssuerURL,
|
||||
Prometheus: promEnabled,
|
||||
STUN: len(stunServers) != 0,
|
||||
Tunnel: tunnel,
|
||||
GitHubOAuth: dflags.OAuth2GithubClientID.Value != "",
|
||||
OIDCAuth: dflags.OIDCClientID.Value != "",
|
||||
OIDCIssuerURL: dflags.OIDCIssuerURL.Value,
|
||||
Prometheus: dflags.PromEnabled.Value,
|
||||
STUN: len(dflags.DerpServerSTUNAddresses.Value) != 0,
|
||||
Tunnel: tunnel != nil,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create telemetry reporter: %w", err)
|
||||
@@ -481,11 +452,11 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
|
||||
|
||||
// This prevents the pprof import from being accidentally deleted.
|
||||
_ = pprof.Handler
|
||||
if pprofEnabled {
|
||||
if dflags.PprofEnabled.Value {
|
||||
//nolint:revive
|
||||
defer serveHandler(ctx, logger, nil, pprofAddress, "pprof")()
|
||||
defer serveHandler(ctx, logger, nil, dflags.PprofAddress.Value, "pprof")()
|
||||
}
|
||||
if promEnabled {
|
||||
if dflags.PromEnabled.Value {
|
||||
options.PrometheusRegistry = prometheus.NewRegistry()
|
||||
closeUsersFunc, err := prometheusmetrics.ActiveUsers(ctx, options.PrometheusRegistry, options.Database, 0)
|
||||
if err != nil {
|
||||
@@ -502,14 +473,20 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
|
||||
//nolint:revive
|
||||
defer serveHandler(ctx, logger, promhttp.InstrumentMetricHandler(
|
||||
options.PrometheusRegistry, promhttp.HandlerFor(options.PrometheusRegistry, promhttp.HandlerOpts{}),
|
||||
), promAddress, "prometheus")()
|
||||
), dflags.PromAddress.Value, "prometheus")()
|
||||
}
|
||||
|
||||
coderAPI := newAPI(options)
|
||||
defer coderAPI.Close()
|
||||
// We use a separate closer so the Enterprise API
|
||||
// can have it's own close functions. This is cleaner
|
||||
// than abstracting the Coder API itself.
|
||||
coderAPI, closer, err := newAPI(ctx, options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer closer.Close()
|
||||
|
||||
client := codersdk.New(localURL)
|
||||
if tlsEnable {
|
||||
if dflags.TLSEnable.Value {
|
||||
// Secure transport isn't needed for locally communicating!
|
||||
client.HTTPClient.Transport = &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
@@ -533,8 +510,8 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
|
||||
_ = daemon.Close()
|
||||
}
|
||||
}()
|
||||
for i := 0; uint8(i) < provisionerDaemonCount; i++ {
|
||||
daemon, err := newProvisionerDaemon(ctx, coderAPI, logger, cacheDir, errCh, false)
|
||||
for i := 0; i < dflags.ProvisionerDaemonCount.Value; i++ {
|
||||
daemon, err := newProvisionerDaemon(ctx, coderAPI, logger, dflags.CacheDir.Value, errCh, false)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create provisioner daemon: %w", err)
|
||||
}
|
||||
@@ -552,30 +529,30 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
|
||||
// These errors are typically noise like "TLS: EOF". Vault does similar:
|
||||
// https://github.com/hashicorp/vault/blob/e2490059d0711635e529a4efcbaa1b26998d6e1c/command/server.go#L2714
|
||||
ErrorLog: log.New(io.Discard, "", 0),
|
||||
Handler: coderAPI.Handler,
|
||||
Handler: coderAPI.RootHandler,
|
||||
BaseContext: func(_ net.Listener) context.Context {
|
||||
return shutdownConnsCtx
|
||||
},
|
||||
}
|
||||
defer func() {
|
||||
_ = shutdownWithTimeout(server, 5*time.Second)
|
||||
_ = shutdownWithTimeout(server.Shutdown, 5*time.Second)
|
||||
}()
|
||||
|
||||
eg := errgroup.Group{}
|
||||
eg.Go(func() error {
|
||||
// Make sure to close the tunnel listener if we exit so the
|
||||
// errgroup doesn't wait forever!
|
||||
if tunnel {
|
||||
defer devTunnel.Listener.Close()
|
||||
if tunnel != nil {
|
||||
defer tunnel.Listener.Close()
|
||||
}
|
||||
|
||||
return server.Serve(listener)
|
||||
})
|
||||
if tunnel {
|
||||
if tunnel != nil {
|
||||
eg.Go(func() error {
|
||||
defer listener.Close()
|
||||
|
||||
return server.Serve(devTunnel.Listener)
|
||||
return server.Serve(tunnel.Listener)
|
||||
})
|
||||
}
|
||||
go func() {
|
||||
@@ -600,7 +577,7 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
|
||||
return xerrors.Errorf("notify systemd: %w", err)
|
||||
}
|
||||
|
||||
autobuildPoller := time.NewTicker(autobuildPollInterval)
|
||||
autobuildPoller := time.NewTicker(dflags.AutobuildPollInterval.Value)
|
||||
defer autobuildPoller.Stop()
|
||||
autobuildExecutor := executor.New(ctx, options.Database, logger, autobuildPoller.C)
|
||||
autobuildExecutor.Run()
|
||||
@@ -620,7 +597,7 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
|
||||
_, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Bold.Render(
|
||||
"Interrupt caught, gracefully exiting. Use ctrl+\\ to force quit",
|
||||
))
|
||||
case exitErr = <-devTunnelErr:
|
||||
case exitErr = <-tunnelErr:
|
||||
if exitErr == nil {
|
||||
exitErr = xerrors.New("dev tunnel closed unexpectedly")
|
||||
}
|
||||
@@ -646,9 +623,9 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
|
||||
// in-flight requests, give in-flight requests 5 seconds to
|
||||
// complete.
|
||||
cmd.Println("Shutting down API server...")
|
||||
err = shutdownWithTimeout(server, 5*time.Second)
|
||||
err = shutdownWithTimeout(server.Shutdown, 3*time.Second)
|
||||
if err != nil {
|
||||
cmd.Printf("API server shutdown took longer than 5s: %s", err)
|
||||
cmd.Printf("API server shutdown took longer than 3s: %s\n", err)
|
||||
} else {
|
||||
cmd.Printf("Gracefully shut down API server\n")
|
||||
}
|
||||
@@ -665,10 +642,10 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
if verbose {
|
||||
if dflags.Verbose.Value {
|
||||
cmd.Printf("Shutting down provisioner daemon %d...\n", id)
|
||||
}
|
||||
err := shutdownWithTimeout(provisionerDaemon, 5*time.Second)
|
||||
err := shutdownWithTimeout(provisionerDaemon.Shutdown, 5*time.Second)
|
||||
if err != nil {
|
||||
cmd.PrintErrf("Failed to shutdown provisioner daemon %d: %s\n", id, err)
|
||||
return
|
||||
@@ -678,7 +655,7 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
|
||||
cmd.PrintErrf("Close provisioner daemon %d: %s\n", id, err)
|
||||
return
|
||||
}
|
||||
if verbose {
|
||||
if dflags.Verbose.Value {
|
||||
cmd.Printf("Gracefully shut down provisioner daemon %d\n", id)
|
||||
}
|
||||
}()
|
||||
@@ -687,13 +664,13 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
|
||||
|
||||
cmd.Println("Waiting for WebSocket connections to close...")
|
||||
_ = coderAPI.Close()
|
||||
cmd.Println("Done wainting for WebSocket connections")
|
||||
cmd.Println("Done waiting for WebSocket connections")
|
||||
|
||||
// Close tunnel after we no longer have in-flight connections.
|
||||
if tunnel {
|
||||
if tunnel != nil {
|
||||
cmd.Println("Waiting for tunnel to close...")
|
||||
closeTunnel()
|
||||
<-devTunnelErr
|
||||
<-tunnelErr
|
||||
cmd.Println("Done waiting for tunnel")
|
||||
}
|
||||
|
||||
@@ -703,6 +680,9 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
|
||||
// Trigger context cancellation for any remaining services.
|
||||
cancel()
|
||||
|
||||
if xerrors.Is(exitErr, context.Canceled) {
|
||||
return nil
|
||||
}
|
||||
return exitErr
|
||||
},
|
||||
}
|
||||
@@ -727,7 +707,7 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cfg := createConfig(cmd)
|
||||
logger := slog.Make(sloghuman.Sink(cmd.ErrOrStderr()))
|
||||
if verbose {
|
||||
if dflags.Verbose.Value {
|
||||
logger = logger.Leveled(slog.LevelDebug)
|
||||
}
|
||||
|
||||
@@ -748,110 +728,7 @@ func Server(newAPI func(*coderd.Options) *coderd.API) *cobra.Command {
|
||||
},
|
||||
})
|
||||
|
||||
cliflag.DurationVarP(root.Flags(), &autobuildPollInterval, "autobuild-poll-interval", "", "CODER_AUTOBUILD_POLL_INTERVAL", time.Minute, "Specifies the interval at which to poll for and execute automated workspace build operations.")
|
||||
cliflag.StringVarP(root.Flags(), &accessURL, "access-url", "", "CODER_ACCESS_URL", "", "Specifies the external URL to access Coder.")
|
||||
cliflag.StringVarP(root.Flags(), &address, "address", "a", "CODER_ADDRESS", "127.0.0.1:3000", "The address to serve the API and dashboard.")
|
||||
cliflag.StringVarP(root.Flags(), &derpConfigURL, "derp-config-url", "", "CODER_DERP_CONFIG_URL", "",
|
||||
"Specifies a URL to periodically fetch a DERP map. See: https://tailscale.com/kb/1118/custom-derp-servers/")
|
||||
cliflag.BoolVarP(root.Flags(), &derpServerEnabled, "derp-server-enable", "", "CODER_DERP_SERVER_ENABLE", true, "Specifies whether to enable or disable the embedded DERP server.")
|
||||
cliflag.IntVarP(root.Flags(), &derpServerRegionID, "derp-server-region-id", "", "CODER_DERP_SERVER_REGION_ID", 999, "Specifies the region ID to use for the embedded DERP server.")
|
||||
cliflag.StringVarP(root.Flags(), &derpServerRegionCode, "derp-server-region-code", "", "CODER_DERP_SERVER_REGION_CODE", "coder", "Specifies the region code that is displayed in the Coder UI for the embedded DERP server.")
|
||||
cliflag.StringVarP(root.Flags(), &derpServerRegionName, "derp-server-region-name", "", "CODER_DERP_SERVER_REGION_NAME", "Coder Embedded DERP", "Specifies the region name that is displayed in the Coder UI for the embedded DERP server.")
|
||||
cliflag.StringArrayVarP(root.Flags(), &derpServerSTUNAddrs, "derp-server-stun-addresses", "", "CODER_DERP_SERVER_STUN_ADDRESSES", []string{
|
||||
"stun.l.google.com:19302",
|
||||
}, "Specify addresses for STUN servers to establish P2P connections. Set empty to disable P2P connections entirely.")
|
||||
|
||||
// Mark hidden while this feature is in testing!
|
||||
_ = root.Flags().MarkHidden("derp-config-url")
|
||||
_ = root.Flags().MarkHidden("derp-server-enable")
|
||||
_ = root.Flags().MarkHidden("derp-server-region-id")
|
||||
_ = root.Flags().MarkHidden("derp-server-region-code")
|
||||
_ = root.Flags().MarkHidden("derp-server-region-name")
|
||||
_ = root.Flags().MarkHidden("derp-server-stun-addresses")
|
||||
|
||||
cliflag.BoolVarP(root.Flags(), &promEnabled, "prometheus-enable", "", "CODER_PROMETHEUS_ENABLE", false, "Enable serving prometheus metrics on the addressdefined by --prometheus-address.")
|
||||
cliflag.StringVarP(root.Flags(), &promAddress, "prometheus-address", "", "CODER_PROMETHEUS_ADDRESS", "127.0.0.1:2112", "The address to serve prometheus metrics.")
|
||||
cliflag.BoolVarP(root.Flags(), &pprofEnabled, "pprof-enable", "", "CODER_PPROF_ENABLE", false, "Enable serving pprof metrics on the address defined by --pprof-address.")
|
||||
cliflag.StringVarP(root.Flags(), &pprofAddress, "pprof-address", "", "CODER_PPROF_ADDRESS", "127.0.0.1:6060", "The address to serve pprof.")
|
||||
defaultCacheDir := filepath.Join(os.TempDir(), "coder-cache")
|
||||
if dir := os.Getenv("CACHE_DIRECTORY"); dir != "" {
|
||||
// For compatibility with systemd.
|
||||
defaultCacheDir = dir
|
||||
}
|
||||
cliflag.StringVarP(root.Flags(), &cacheDir, "cache-dir", "", "CODER_CACHE_DIRECTORY", defaultCacheDir, "Specifies a directory to cache binaries for provision operations. If unspecified and $CACHE_DIRECTORY is set, it will be used for compatibility with systemd.")
|
||||
cliflag.BoolVarP(root.Flags(), &inMemoryDatabase, "in-memory", "", "CODER_INMEMORY", false,
|
||||
"Specifies whether data will be stored in an in-memory database.")
|
||||
_ = root.Flags().MarkHidden("in-memory")
|
||||
cliflag.StringVarP(root.Flags(), &postgresURL, "postgres-url", "", "CODER_PG_CONNECTION_URL", "", "The URL of a PostgreSQL database to connect to. If empty, PostgreSQL binaries will be downloaded from Maven (https://repo1.maven.org/maven2) and store all data in the config root. Access the built-in database with \"coder server postgres-builtin-url\"")
|
||||
cliflag.Uint8VarP(root.Flags(), &provisionerDaemonCount, "provisioner-daemons", "", "CODER_PROVISIONER_DAEMONS", 3, "The amount of provisioner daemons to create on start.")
|
||||
cliflag.StringVarP(root.Flags(), &oauth2GithubClientID, "oauth2-github-client-id", "", "CODER_OAUTH2_GITHUB_CLIENT_ID", "",
|
||||
"Specifies a client ID to use for oauth2 with GitHub.")
|
||||
cliflag.StringVarP(root.Flags(), &oauth2GithubClientSecret, "oauth2-github-client-secret", "", "CODER_OAUTH2_GITHUB_CLIENT_SECRET", "",
|
||||
"Specifies a client secret to use for oauth2 with GitHub.")
|
||||
cliflag.StringArrayVarP(root.Flags(), &oauth2GithubAllowedOrganizations, "oauth2-github-allowed-orgs", "", "CODER_OAUTH2_GITHUB_ALLOWED_ORGS", nil,
|
||||
"Specifies organizations the user must be a member of to authenticate with GitHub.")
|
||||
cliflag.StringArrayVarP(root.Flags(), &oauth2GithubAllowedTeams, "oauth2-github-allowed-teams", "", "CODER_OAUTH2_GITHUB_ALLOWED_TEAMS", nil,
|
||||
"Specifies teams inside organizations the user must be a member of to authenticate with GitHub. Formatted as: <organization-name>/<team-slug>.")
|
||||
cliflag.BoolVarP(root.Flags(), &oauth2GithubAllowSignups, "oauth2-github-allow-signups", "", "CODER_OAUTH2_GITHUB_ALLOW_SIGNUPS", false,
|
||||
"Specifies whether new users can sign up with GitHub.")
|
||||
cliflag.StringVarP(root.Flags(), &oauth2GithubEnterpriseBaseURL, "oauth2-github-enterprise-base-url", "", "CODER_OAUTH2_GITHUB_ENTERPRISE_BASE_URL", "",
|
||||
"Specifies the base URL of a GitHub Enterprise instance to use for oauth2.")
|
||||
cliflag.BoolVarP(root.Flags(), &oidcAllowSignups, "oidc-allow-signups", "", "CODER_OIDC_ALLOW_SIGNUPS", true,
|
||||
"Specifies whether new users can sign up with OIDC.")
|
||||
cliflag.StringVarP(root.Flags(), &oidcClientID, "oidc-client-id", "", "CODER_OIDC_CLIENT_ID", "",
|
||||
"Specifies a client ID to use for OIDC.")
|
||||
cliflag.StringVarP(root.Flags(), &oidcClientSecret, "oidc-client-secret", "", "CODER_OIDC_CLIENT_SECRET", "",
|
||||
"Specifies a client secret to use for OIDC.")
|
||||
cliflag.StringVarP(root.Flags(), &oidcEmailDomain, "oidc-email-domain", "", "CODER_OIDC_EMAIL_DOMAIN", "",
|
||||
"Specifies an email domain that clients authenticating with OIDC must match.")
|
||||
cliflag.StringVarP(root.Flags(), &oidcIssuerURL, "oidc-issuer-url", "", "CODER_OIDC_ISSUER_URL", "",
|
||||
"Specifies an issuer URL to use for OIDC.")
|
||||
cliflag.StringArrayVarP(root.Flags(), &oidcScopes, "oidc-scopes", "", "CODER_OIDC_SCOPES", []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
"Specifies scopes to grant when authenticating with OIDC.")
|
||||
cliflag.BoolVarP(root.Flags(), &tailscaleEnable, "tailscale", "", "CODER_TAILSCALE", false,
|
||||
"Specifies whether Tailscale networking is used for web applications and terminals.")
|
||||
_ = root.Flags().MarkHidden("tailscale")
|
||||
enableTelemetryByDefault := !isTest()
|
||||
cliflag.BoolVarP(root.Flags(), &telemetryEnable, "telemetry", "", "CODER_TELEMETRY", enableTelemetryByDefault, "Specifies whether telemetry is enabled or not. Coder collects anonymized usage data to help improve our product.")
|
||||
cliflag.StringVarP(root.Flags(), &telemetryURL, "telemetry-url", "", "CODER_TELEMETRY_URL", "https://telemetry.coder.com", "Specifies a URL to send telemetry to.")
|
||||
_ = root.Flags().MarkHidden("telemetry-url")
|
||||
cliflag.BoolVarP(root.Flags(), &tlsEnable, "tls-enable", "", "CODER_TLS_ENABLE", false, "Specifies if TLS will be enabled")
|
||||
cliflag.StringVarP(root.Flags(), &tlsCertFile, "tls-cert-file", "", "CODER_TLS_CERT_FILE", "",
|
||||
"Specifies the path to the certificate for TLS. It requires a PEM-encoded file. "+
|
||||
"To configure the listener to use a CA certificate, concatenate the primary certificate "+
|
||||
"and the CA certificate together. The primary certificate should appear first in the combined file")
|
||||
cliflag.StringVarP(root.Flags(), &tlsClientCAFile, "tls-client-ca-file", "", "CODER_TLS_CLIENT_CA_FILE", "",
|
||||
"PEM-encoded Certificate Authority file used for checking the authenticity of client")
|
||||
cliflag.StringVarP(root.Flags(), &tlsClientAuth, "tls-client-auth", "", "CODER_TLS_CLIENT_AUTH", "request",
|
||||
`Specifies the policy the server will follow for TLS Client Authentication. `+
|
||||
`Accepted values are "none", "request", "require-any", "verify-if-given", or "require-and-verify"`)
|
||||
cliflag.StringVarP(root.Flags(), &tlsKeyFile, "tls-key-file", "", "CODER_TLS_KEY_FILE", "",
|
||||
"Specifies the path to the private key for the certificate. It requires a PEM-encoded file")
|
||||
cliflag.StringVarP(root.Flags(), &tlsMinVersion, "tls-min-version", "", "CODER_TLS_MIN_VERSION", "tls12",
|
||||
`Specifies the minimum supported version of TLS. Accepted values are "tls10", "tls11", "tls12" or "tls13"`)
|
||||
cliflag.BoolVarP(root.Flags(), &tunnel, "tunnel", "", "CODER_TUNNEL", false,
|
||||
"Workspaces must be able to reach the `access-url`. This overrides your access URL with a public access URL that tunnels your Coder deployment.")
|
||||
cliflag.StringArrayVarP(root.Flags(), &stunServers, "stun-server", "", "CODER_STUN_SERVERS", []string{
|
||||
"stun:stun.l.google.com:19302",
|
||||
}, "Specify URLs for STUN servers to enable P2P connections.")
|
||||
cliflag.BoolVarP(root.Flags(), &trace, "trace", "", "CODER_TRACE", false, "Specifies if application tracing data is collected")
|
||||
cliflag.StringVarP(root.Flags(), &turnRelayAddress, "turn-relay-address", "", "CODER_TURN_RELAY_ADDRESS", "127.0.0.1",
|
||||
"Specifies the address to bind TURN connections.")
|
||||
cliflag.BoolVarP(root.Flags(), &secureAuthCookie, "secure-auth-cookie", "", "CODER_SECURE_AUTH_COOKIE", false, "Specifies if the 'Secure' property is set on browser session cookies")
|
||||
cliflag.StringVarP(root.Flags(), &sshKeygenAlgorithmRaw, "ssh-keygen-algorithm", "", "CODER_SSH_KEYGEN_ALGORITHM", "ed25519", "Specifies the algorithm to use for generating ssh keys. "+
|
||||
`Accepted values are "ed25519", "ecdsa", or "rsa4096"`)
|
||||
cliflag.StringArrayVarP(root.Flags(), &autoImportTemplates, "auto-import-template", "", "CODER_TEMPLATE_AUTOIMPORT", []string{}, "Which templates to auto-import. Available auto-importable templates are: kubernetes")
|
||||
cliflag.BoolVarP(root.Flags(), &spooky, "spooky", "", "", false, "Specifies spookiness level")
|
||||
_ = root.Flags().MarkHidden("spooky")
|
||||
cliflag.BoolVarP(root.Flags(), &verbose, "verbose", "v", "CODER_VERBOSE", false, "Enables verbose logging.")
|
||||
|
||||
// These metrics flags are for manually testing the metric system.
|
||||
// The defaults should be acceptable for any Coder deployment of any
|
||||
// reasonable size.
|
||||
cliflag.DurationVarP(root.Flags(), &metricsCacheRefreshInterval, "metrics-cache-refresh-interval", "", "CODER_METRICS_CACHE_REFRESH_INTERVAL", time.Hour, "How frequently metrics are refreshed")
|
||||
_ = root.Flags().MarkHidden("metrics-cache-refresh-interval")
|
||||
cliflag.DurationVarP(root.Flags(), &agentStatRefreshInterval, "agent-stats-refresh-interval", "", "CODER_AGENT_STATS_REFRESH_INTERVAL", time.Minute*10, "How frequently agent stats are recorded")
|
||||
_ = root.Flags().MarkHidden("agent-stats-report-interval")
|
||||
deployment.AttachFlags(root.Flags(), dflags, false)
|
||||
|
||||
return root
|
||||
}
|
||||
@@ -904,15 +781,20 @@ func isLocalURL(ctx context.Context, u *url.URL) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func shutdownWithTimeout(s interface{ Shutdown(context.Context) error }, timeout time.Duration) error {
|
||||
func shutdownWithTimeout(shutdown func(context.Context) error, timeout time.Duration) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
return s.Shutdown(ctx)
|
||||
return shutdown(ctx)
|
||||
}
|
||||
|
||||
// nolint:revive
|
||||
func newProvisionerDaemon(ctx context.Context, coderAPI *coderd.API,
|
||||
logger slog.Logger, cacheDir string, errCh chan error, dev bool,
|
||||
func newProvisionerDaemon(
|
||||
ctx context.Context,
|
||||
coderAPI *coderd.API,
|
||||
logger slog.Logger,
|
||||
cacheDir string,
|
||||
errCh chan error,
|
||||
dev bool,
|
||||
) (srv *provisionerd.Server, err error) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer func() {
|
||||
@@ -985,29 +867,41 @@ func newProvisionerDaemon(ctx context.Context, coderAPI *coderd.API,
|
||||
UpdateInterval: 500 * time.Millisecond,
|
||||
Provisioners: provisioners,
|
||||
WorkDirectory: tempDir,
|
||||
Tracer: coderAPI.TracerProvider,
|
||||
}), nil
|
||||
}
|
||||
|
||||
// nolint: revive
|
||||
func printLogo(cmd *cobra.Command, spooky bool) {
|
||||
if spooky {
|
||||
_, _ = fmt.Fprintf(cmd.OutOrStdout(), `▄████▄ ▒█████ ▓█████▄ ▓█████ ██▀███
|
||||
▒██▀ ▀█ ▒██▒ ██▒▒██▀ ██▌▓█ ▀ ▓██ ▒ ██▒
|
||||
▒▓█ ▄ ▒██░ ██▒░██ █▌▒███ ▓██ ░▄█ ▒
|
||||
▒▓▓▄ ▄██▒▒██ ██░░▓█▄ ▌▒▓█ ▄ ▒██▀▀█▄
|
||||
▒ ▓███▀ ░░ ████▓▒░░▒████▓ ░▒████▒░██▓ ▒██▒
|
||||
░ ░▒ ▒ ░░ ▒░▒░▒░ ▒▒▓ ▒ ░░ ▒░ ░░ ▒▓ ░▒▓░
|
||||
░ ▒ ░ ▒ ▒░ ░ ▒ ▒ ░ ░ ░ ░▒ ░ ▒░
|
||||
░ ░ ░ ░ ▒ ░ ░ ░ ░ ░░ ░
|
||||
░ ░ ░ ░ ░ ░ ░ ░
|
||||
░ ░
|
||||
`)
|
||||
return
|
||||
}
|
||||
func printLogo(cmd *cobra.Command) {
|
||||
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s - Remote development on your infrastucture\n", cliui.Styles.Bold.Render("Coder "+buildinfo.Version()))
|
||||
}
|
||||
|
||||
func configureTLS(listener net.Listener, tlsMinVersion, tlsClientAuth, tlsCertFile, tlsKeyFile, tlsClientCAFile string) (net.Listener, error) {
|
||||
func loadCertificates(tlsCertFiles, tlsKeyFiles []string) ([]tls.Certificate, error) {
|
||||
if len(tlsCertFiles) != len(tlsKeyFiles) {
|
||||
return nil, xerrors.New("--tls-cert-file and --tls-key-file must be used the same amount of times")
|
||||
}
|
||||
if len(tlsCertFiles) == 0 {
|
||||
return nil, xerrors.New("--tls-cert-file is required when tls is enabled")
|
||||
}
|
||||
if len(tlsKeyFiles) == 0 {
|
||||
return nil, xerrors.New("--tls-key-file is required when tls is enabled")
|
||||
}
|
||||
|
||||
certs := make([]tls.Certificate, len(tlsCertFiles))
|
||||
for i := range tlsCertFiles {
|
||||
certFile, keyFile := tlsCertFiles[i], tlsKeyFiles[i]
|
||||
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("load TLS key pair %d (%q, %q): %w", i, certFile, keyFile, err)
|
||||
}
|
||||
|
||||
certs[i] = cert
|
||||
}
|
||||
|
||||
return certs, nil
|
||||
}
|
||||
|
||||
func configureTLS(tlsMinVersion, tlsClientAuth string, tlsCertFiles, tlsKeyFiles []string, tlsClientCAFile string) (*tls.Config, error) {
|
||||
tlsConfig := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
@@ -1039,36 +933,32 @@ func configureTLS(listener net.Listener, tlsMinVersion, tlsClientAuth, tlsCertFi
|
||||
return nil, xerrors.Errorf("unrecognized tls client auth: %q", tlsClientAuth)
|
||||
}
|
||||
|
||||
if tlsCertFile == "" {
|
||||
return nil, xerrors.New("tls-cert-file is required when tls is enabled")
|
||||
}
|
||||
if tlsKeyFile == "" {
|
||||
return nil, xerrors.New("tls-key-file is required when tls is enabled")
|
||||
certs, err := loadCertificates(tlsCertFiles, tlsKeyFiles)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("load certificates: %w", err)
|
||||
}
|
||||
tlsConfig.Certificates = certs
|
||||
tlsConfig.GetCertificate = func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
// If there's only one certificate, return it.
|
||||
if len(certs) == 1 {
|
||||
return &certs[0], nil
|
||||
}
|
||||
|
||||
certPEMBlock, err := os.ReadFile(tlsCertFile)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("read file %q: %w", tlsCertFile, err)
|
||||
}
|
||||
keyPEMBlock, err := os.ReadFile(tlsKeyFile)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("read file %q: %w", tlsKeyFile, err)
|
||||
}
|
||||
keyBlock, _ := pem.Decode(keyPEMBlock)
|
||||
if keyBlock == nil {
|
||||
return nil, xerrors.New("decoded pem is blank")
|
||||
}
|
||||
cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create key pair: %w", err)
|
||||
}
|
||||
tlsConfig.GetCertificate = func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return &cert, nil
|
||||
}
|
||||
// Expensively check which certificate matches the client hello.
|
||||
for _, cert := range certs {
|
||||
cert := cert
|
||||
if err := hi.SupportsCertificate(&cert); err == nil {
|
||||
return &cert, nil
|
||||
}
|
||||
}
|
||||
|
||||
certPool := x509.NewCertPool()
|
||||
certPool.AppendCertsFromPEM(certPEMBlock)
|
||||
tlsConfig.RootCAs = certPool
|
||||
// Return the first certificate if we have one, or return nil so the
|
||||
// server doesn't fail.
|
||||
if len(certs) > 0 {
|
||||
return &certs[0], nil
|
||||
}
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
if tlsClientCAFile != "" {
|
||||
caPool := x509.NewCertPool()
|
||||
@@ -1082,7 +972,7 @@ func configureTLS(listener net.Listener, tlsMinVersion, tlsClientAuth, tlsCertFi
|
||||
tlsConfig.ClientCAs = caPool
|
||||
}
|
||||
|
||||
return tls.NewListener(listener, tlsConfig), nil
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
func configureGithubOAuth2(accessURL *url.URL, clientID, clientSecret string, allowSignups bool, allowOrgs []string, rawTeams []string, enterpriseBaseURL string) (*coderd.GithubOAuth2Config, error) {
|
||||
|
||||
+185
-34
@@ -21,10 +21,11 @@ import (
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/goleak"
|
||||
@@ -55,6 +56,7 @@ func TestServer(t *testing.T) {
|
||||
root, cfg := clitest.New(t,
|
||||
"server",
|
||||
"--address", ":0",
|
||||
"--access-url", "example.com",
|
||||
"--postgres-url", connectionURL,
|
||||
"--cache-dir", t.TempDir(),
|
||||
)
|
||||
@@ -73,7 +75,7 @@ func TestServer(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
cancelFunc()
|
||||
require.ErrorIs(t, <-errC, context.Canceled)
|
||||
require.NoError(t, <-errC)
|
||||
})
|
||||
t.Run("BuiltinPostgres", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
@@ -86,6 +88,7 @@ func TestServer(t *testing.T) {
|
||||
root, cfg := clitest.New(t,
|
||||
"server",
|
||||
"--address", ":0",
|
||||
"--access-url", "example.com",
|
||||
"--cache-dir", t.TempDir(),
|
||||
)
|
||||
pty := ptytest.New(t)
|
||||
@@ -101,7 +104,7 @@ func TestServer(t *testing.T) {
|
||||
return err == nil && rawURL != ""
|
||||
}, 3*time.Minute, testutil.IntervalFast, "failed to get access URL")
|
||||
cancelFunc()
|
||||
require.ErrorIs(t, <-errC, context.Canceled)
|
||||
require.NoError(t, <-errC)
|
||||
})
|
||||
t.Run("BuiltinPostgresURL", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
@@ -129,8 +132,9 @@ func TestServer(t *testing.T) {
|
||||
"--access-url", "localhost:3000/",
|
||||
"--cache-dir", t.TempDir(),
|
||||
)
|
||||
buf := newThreadSafeBuffer()
|
||||
root.SetOutput(buf)
|
||||
pty := ptytest.New(t)
|
||||
root.SetIn(pty.Input())
|
||||
root.SetOut(pty.Output())
|
||||
errC := make(chan error, 1)
|
||||
go func() {
|
||||
errC <- root.ExecuteContext(ctx)
|
||||
@@ -139,10 +143,11 @@ func TestServer(t *testing.T) {
|
||||
// Just wait for startup
|
||||
_ = waitAccessURL(t, cfg)
|
||||
|
||||
pty.ExpectMatch("this may cause unexpected problems when creating workspaces")
|
||||
pty.ExpectMatch("View the Web UI: http://localhost:3000/")
|
||||
|
||||
cancelFunc()
|
||||
require.ErrorIs(t, <-errC, context.Canceled)
|
||||
require.Contains(t, buf.String(), "this may cause unexpected problems when creating workspaces")
|
||||
require.Contains(t, buf.String(), "View the Web UI: http://localhost:3000/\n")
|
||||
require.NoError(t, <-errC)
|
||||
})
|
||||
|
||||
// Validate that an https scheme is prepended to a remote access URL
|
||||
@@ -156,11 +161,13 @@ func TestServer(t *testing.T) {
|
||||
"server",
|
||||
"--in-memory",
|
||||
"--address", ":0",
|
||||
"--access-url", "example.com",
|
||||
"--access-url", "foobarbaz.mydomain",
|
||||
"--cache-dir", t.TempDir(),
|
||||
)
|
||||
buf := newThreadSafeBuffer()
|
||||
root.SetOutput(buf)
|
||||
pty := ptytest.New(t)
|
||||
root.SetIn(pty.Input())
|
||||
root.SetOut(pty.Output())
|
||||
errC := make(chan error, 1)
|
||||
go func() {
|
||||
errC <- root.ExecuteContext(ctx)
|
||||
@@ -169,10 +176,11 @@ func TestServer(t *testing.T) {
|
||||
// Just wait for startup
|
||||
_ = waitAccessURL(t, cfg)
|
||||
|
||||
pty.ExpectMatch("this may cause unexpected problems when creating workspaces")
|
||||
pty.ExpectMatch("View the Web UI: https://foobarbaz.mydomain")
|
||||
|
||||
cancelFunc()
|
||||
require.ErrorIs(t, <-errC, context.Canceled)
|
||||
require.Contains(t, buf.String(), "this may cause unexpected problems when creating workspaces")
|
||||
require.Contains(t, buf.String(), "View the Web UI: https://foobarbaz.mydomain\n")
|
||||
require.NoError(t, <-errC)
|
||||
})
|
||||
|
||||
t.Run("NoWarningWithRemoteAccessURL", func(t *testing.T) {
|
||||
@@ -184,11 +192,13 @@ func TestServer(t *testing.T) {
|
||||
"server",
|
||||
"--in-memory",
|
||||
"--address", ":0",
|
||||
"--access-url", "example.com",
|
||||
"--access-url", "https://google.com",
|
||||
"--cache-dir", t.TempDir(),
|
||||
)
|
||||
buf := newThreadSafeBuffer()
|
||||
root.SetOutput(buf)
|
||||
pty := ptytest.New(t)
|
||||
root.SetIn(pty.Input())
|
||||
root.SetOut(pty.Output())
|
||||
errC := make(chan error, 1)
|
||||
go func() {
|
||||
errC <- root.ExecuteContext(ctx)
|
||||
@@ -197,10 +207,10 @@ func TestServer(t *testing.T) {
|
||||
// Just wait for startup
|
||||
_ = waitAccessURL(t, cfg)
|
||||
|
||||
pty.ExpectMatch("View the Web UI: https://google.com")
|
||||
|
||||
cancelFunc()
|
||||
require.ErrorIs(t, <-errC, context.Canceled)
|
||||
require.NotContains(t, buf.String(), "this may cause unexpected problems when creating workspaces")
|
||||
require.Contains(t, buf.String(), "View the Web UI: https://google.com\n")
|
||||
require.NoError(t, <-errC)
|
||||
})
|
||||
|
||||
t.Run("TLSBadVersion", func(t *testing.T) {
|
||||
@@ -212,6 +222,7 @@ func TestServer(t *testing.T) {
|
||||
"server",
|
||||
"--in-memory",
|
||||
"--address", ":0",
|
||||
"--access-url", "example.com",
|
||||
"--tls-enable",
|
||||
"--tls-min-version", "tls9",
|
||||
"--cache-dir", t.TempDir(),
|
||||
@@ -228,6 +239,7 @@ func TestServer(t *testing.T) {
|
||||
"server",
|
||||
"--in-memory",
|
||||
"--address", ":0",
|
||||
"--access-url", "example.com",
|
||||
"--tls-enable",
|
||||
"--tls-client-auth", "something",
|
||||
"--cache-dir", t.TempDir(),
|
||||
@@ -235,20 +247,65 @@ func TestServer(t *testing.T) {
|
||||
err := root.ExecuteContext(ctx)
|
||||
require.Error(t, err)
|
||||
})
|
||||
t.Run("TLSNoCertFile", func(t *testing.T) {
|
||||
t.Run("TLSInvalid", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
root, _ := clitest.New(t,
|
||||
"server",
|
||||
"--in-memory",
|
||||
"--address", ":0",
|
||||
"--tls-enable",
|
||||
"--cache-dir", t.TempDir(),
|
||||
)
|
||||
err := root.ExecuteContext(ctx)
|
||||
require.Error(t, err)
|
||||
cert1Path, key1Path := generateTLSCertificate(t)
|
||||
cert2Path, key2Path := generateTLSCertificate(t)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
args []string
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "NoCertAndKey",
|
||||
args: []string{"--tls-enable"},
|
||||
errContains: "--tls-cert-file is required when tls is enabled",
|
||||
},
|
||||
{
|
||||
name: "NoCert",
|
||||
args: []string{"--tls-enable", "--tls-key-file", key1Path},
|
||||
errContains: "--tls-cert-file and --tls-key-file must be used the same amount of times",
|
||||
},
|
||||
{
|
||||
name: "NoKey",
|
||||
args: []string{"--tls-enable", "--tls-cert-file", cert1Path},
|
||||
errContains: "--tls-cert-file and --tls-key-file must be used the same amount of times",
|
||||
},
|
||||
{
|
||||
name: "MismatchedCount",
|
||||
args: []string{"--tls-enable", "--tls-cert-file", cert1Path, "--tls-key-file", key1Path, "--tls-cert-file", cert2Path},
|
||||
errContains: "--tls-cert-file and --tls-key-file must be used the same amount of times",
|
||||
},
|
||||
{
|
||||
name: "MismatchedCertAndKey",
|
||||
args: []string{"--tls-enable", "--tls-cert-file", cert1Path, "--tls-key-file", key2Path},
|
||||
errContains: "load TLS key pair",
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
c := c
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
args := []string{
|
||||
"server",
|
||||
"--in-memory",
|
||||
"--address", ":0",
|
||||
"--access-url", "example.com",
|
||||
"--cache-dir", t.TempDir(),
|
||||
}
|
||||
args = append(args, c.args...)
|
||||
root, _ := clitest.New(t, args...)
|
||||
err := root.ExecuteContext(ctx)
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, c.errContains)
|
||||
})
|
||||
}
|
||||
})
|
||||
t.Run("TLSValid", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
@@ -260,6 +317,7 @@ func TestServer(t *testing.T) {
|
||||
"server",
|
||||
"--in-memory",
|
||||
"--address", ":0",
|
||||
"--access-url", "example.com",
|
||||
"--tls-enable",
|
||||
"--tls-cert-file", certPath,
|
||||
"--tls-key-file", keyPath,
|
||||
@@ -286,7 +344,88 @@ func TestServer(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
cancelFunc()
|
||||
require.ErrorIs(t, <-errC, context.Canceled)
|
||||
require.NoError(t, <-errC)
|
||||
})
|
||||
t.Run("TLSValidMultiple", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
cert1Path, key1Path := generateTLSCertificate(t, "alpaca.com")
|
||||
cert2Path, key2Path := generateTLSCertificate(t, "*.llama.com")
|
||||
root, cfg := clitest.New(t,
|
||||
"server",
|
||||
"--in-memory",
|
||||
"--address", ":0",
|
||||
"--access-url", "example.com",
|
||||
"--tls-enable",
|
||||
"--tls-cert-file", cert1Path,
|
||||
"--tls-key-file", key1Path,
|
||||
"--tls-cert-file", cert2Path,
|
||||
"--tls-key-file", key2Path,
|
||||
"--cache-dir", t.TempDir(),
|
||||
)
|
||||
errC := make(chan error, 1)
|
||||
go func() {
|
||||
errC <- root.ExecuteContext(ctx)
|
||||
}()
|
||||
accessURL := waitAccessURL(t, cfg)
|
||||
require.Equal(t, "https", accessURL.Scheme)
|
||||
originalHost := accessURL.Host
|
||||
|
||||
var (
|
||||
expectAddr string
|
||||
dials int64
|
||||
)
|
||||
client := codersdk.New(accessURL)
|
||||
client.HTTPClient = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
atomic.AddInt64(&dials, 1)
|
||||
assert.Equal(t, expectAddr, addr)
|
||||
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Always connect to the accessURL ip:port regardless of
|
||||
// hostname.
|
||||
conn, err := tls.Dial(network, originalHost, &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
//nolint:gosec
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: host,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// We can't call conn.VerifyHostname because it requires
|
||||
// that the certificates are valid, so we call
|
||||
// VerifyHostname on the first certificate instead.
|
||||
require.Len(t, conn.ConnectionState().PeerCertificates, 1)
|
||||
err = conn.ConnectionState().PeerCertificates[0].VerifyHostname(host)
|
||||
assert.NoError(t, err, "invalid cert common name")
|
||||
return conn, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Use the first certificate and hostname.
|
||||
client.URL.Host = "alpaca.com:443"
|
||||
expectAddr = "alpaca.com:443"
|
||||
_, err := client.HasFirstUser(ctx)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, atomic.LoadInt64(&dials))
|
||||
|
||||
// Use the second certificate (wildcard) and hostname.
|
||||
client.URL.Host = "hi.llama.com:443"
|
||||
expectAddr = "hi.llama.com:443"
|
||||
_, err = client.HasFirstUser(ctx)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 2, atomic.LoadInt64(&dials))
|
||||
|
||||
cancelFunc()
|
||||
require.NoError(t, <-errC)
|
||||
})
|
||||
// This cannot be ran in parallel because it uses a signal.
|
||||
//nolint:paralleltest
|
||||
@@ -302,6 +441,7 @@ func TestServer(t *testing.T) {
|
||||
"server",
|
||||
"--in-memory",
|
||||
"--address", ":0",
|
||||
"--access-url", "example.com",
|
||||
"--provisioner-daemons", "1",
|
||||
"--cache-dir", t.TempDir(),
|
||||
)
|
||||
@@ -317,7 +457,7 @@ func TestServer(t *testing.T) {
|
||||
// We cannot send more signals here, because it's possible Coder
|
||||
// has already exited, which could cause the test to fail due to interrupt.
|
||||
err = <-serverErr
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
t.Run("TracerNoLeak", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
@@ -328,6 +468,7 @@ func TestServer(t *testing.T) {
|
||||
"server",
|
||||
"--in-memory",
|
||||
"--address", ":0",
|
||||
"--access-url", "example.com",
|
||||
"--trace=true",
|
||||
"--cache-dir", t.TempDir(),
|
||||
)
|
||||
@@ -336,7 +477,7 @@ func TestServer(t *testing.T) {
|
||||
errC <- root.ExecuteContext(ctx)
|
||||
}()
|
||||
cancelFunc()
|
||||
require.ErrorIs(t, <-errC, context.Canceled)
|
||||
require.NoError(t, <-errC)
|
||||
require.Error(t, goleak.Find())
|
||||
})
|
||||
t.Run("Telemetry", func(t *testing.T) {
|
||||
@@ -365,6 +506,7 @@ func TestServer(t *testing.T) {
|
||||
"server",
|
||||
"--in-memory",
|
||||
"--address", ":0",
|
||||
"--access-url", "example.com",
|
||||
"--telemetry",
|
||||
"--telemetry-url", server.URL,
|
||||
"--cache-dir", t.TempDir(),
|
||||
@@ -395,6 +537,7 @@ func TestServer(t *testing.T) {
|
||||
"server",
|
||||
"--in-memory",
|
||||
"--address", ":0",
|
||||
"--access-url", "example.com",
|
||||
"--provisioner-daemons", "1",
|
||||
"--prometheus-enable",
|
||||
"--prometheus-address", ":"+strconv.Itoa(randomPort),
|
||||
@@ -447,6 +590,7 @@ func TestServer(t *testing.T) {
|
||||
"server",
|
||||
"--in-memory",
|
||||
"--address", ":0",
|
||||
"--access-url", "example.com",
|
||||
"--oauth2-github-client-id", "fake",
|
||||
"--oauth2-github-client-secret", "fake",
|
||||
"--oauth2-github-enterprise-base-url", fakeRedirect,
|
||||
@@ -475,16 +619,22 @@ func TestServer(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func generateTLSCertificate(t testing.TB) (certPath, keyPath string) {
|
||||
func generateTLSCertificate(t testing.TB, commonName ...string) (certPath, keyPath string) {
|
||||
dir := t.TempDir()
|
||||
|
||||
commonNameStr := "localhost"
|
||||
if len(commonName) > 0 {
|
||||
commonNameStr = commonName[0]
|
||||
}
|
||||
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
template := x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Acme Co"},
|
||||
CommonName: commonNameStr,
|
||||
},
|
||||
DNSNames: []string{commonNameStr},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(time.Hour * 24 * 180),
|
||||
|
||||
@@ -493,6 +643,7 @@ func generateTLSCertificate(t testing.TB) (certPath, keyPath string) {
|
||||
BasicConstraintsValid: true,
|
||||
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||||
}
|
||||
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
|
||||
require.NoError(t, err)
|
||||
certFile, err := os.CreateTemp(dir, "")
|
||||
|
||||
+2
-6
@@ -11,7 +11,7 @@ func show() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Annotations: workspaceCommand,
|
||||
Use: "show <workspace>",
|
||||
Short: "Show details of a workspace's resources and agents",
|
||||
Short: "Display details of a workspace's resources and agents",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
client, err := CreateClient(cmd)
|
||||
@@ -26,11 +26,7 @@ func show() *cobra.Command {
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get workspace: %w", err)
|
||||
}
|
||||
resources, err := client.WorkspaceResourcesByBuild(cmd.Context(), workspace.LatestBuild.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get workspace resources: %w", err)
|
||||
}
|
||||
return cliui.WorkspaceResources(cmd.OutOrStdout(), resources, cliui.WorkspaceResourcesOptions{
|
||||
return cliui.WorkspaceResources(cmd.OutOrStdout(), workspace.LatestBuild.Resources, cliui.WorkspaceResourcesOptions{
|
||||
WorkspaceName: workspace.Name,
|
||||
ServerVersion: buildInfo.Version,
|
||||
})
|
||||
|
||||
+35
-27
@@ -11,7 +11,7 @@ import (
|
||||
tsspeedtest "tailscale.com/net/speedtest"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/agent"
|
||||
"cdr.dev/slog/sloggers/sloghuman"
|
||||
"github.com/coder/coder/cli/cliflag"
|
||||
"github.com/coder/coder/cli/cliui"
|
||||
"github.com/coder/coder/codersdk"
|
||||
@@ -27,7 +27,7 @@ func speedtest() *cobra.Command {
|
||||
Annotations: workspaceCommand,
|
||||
Use: "speedtest <workspace>",
|
||||
Args: cobra.ExactArgs(1),
|
||||
Short: "Run a speed test from your machine to the workspace.",
|
||||
Short: "Run upload and download tests from your machine to a workspace",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
defer cancel()
|
||||
@@ -51,36 +51,44 @@ func speedtest() *cobra.Command {
|
||||
if err != nil {
|
||||
return xerrors.Errorf("await agent: %w", err)
|
||||
}
|
||||
conn, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, workspaceAgent.ID)
|
||||
logger := slog.Make(sloghuman.Sink(cmd.ErrOrStderr()))
|
||||
if cliflag.IsSetBool(cmd, varVerbose) {
|
||||
logger = logger.Leveled(slog.LevelDebug)
|
||||
}
|
||||
conn, err := client.DialWorkspaceAgent(ctx, workspaceAgent.ID, &codersdk.DialWorkspaceAgentOptions{
|
||||
Logger: logger,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
if direct {
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
}
|
||||
dur, err := conn.Ping()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
tc, _ := conn.(*agent.TailnetConn)
|
||||
status := tc.Status()
|
||||
if len(status.Peers()) != 1 {
|
||||
continue
|
||||
}
|
||||
peer := status.Peer[status.Peers()[0]]
|
||||
if peer.CurAddr == "" {
|
||||
cmd.Printf("Waiting for a direct connection... (%dms via %s)\n", dur.Milliseconds(), peer.Relay)
|
||||
continue
|
||||
}
|
||||
break
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
}
|
||||
dur, err := conn.Ping(ctx)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
status := conn.Status()
|
||||
if len(status.Peers()) != 1 {
|
||||
continue
|
||||
}
|
||||
peer := status.Peer[status.Peers()[0]]
|
||||
if peer.CurAddr == "" && direct {
|
||||
cmd.Printf("Waiting for a direct connection... (%dms via %s)\n", dur.Milliseconds(), peer.Relay)
|
||||
continue
|
||||
}
|
||||
via := peer.Relay
|
||||
if via == "" {
|
||||
via = "direct"
|
||||
}
|
||||
cmd.Printf("%dms via %s\n", dur.Milliseconds(), via)
|
||||
break
|
||||
}
|
||||
dir := tsspeedtest.Download
|
||||
if reverse {
|
||||
|
||||
@@ -25,12 +25,11 @@ func TestSpeedtest(t *testing.T) {
|
||||
agentClient.SessionToken = agentToken
|
||||
agentCloser := agent.New(agent.Options{
|
||||
FetchMetadata: agentClient.WorkspaceAgentMetadata,
|
||||
WebRTCDialer: agentClient.ListenWorkspaceAgent,
|
||||
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
|
||||
Logger: slogtest.Make(t, nil).Named("agent"),
|
||||
})
|
||||
defer agentCloser.Close()
|
||||
coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
|
||||
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
|
||||
|
||||
cmd, root := clitest.New(t, "speedtest", workspace.Name)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
|
||||
+10
-18
@@ -20,9 +20,6 @@ import (
|
||||
"golang.org/x/term"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
||||
"github.com/coder/coder/agent"
|
||||
"github.com/coder/coder/cli/cliflag"
|
||||
"github.com/coder/coder/cli/cliui"
|
||||
"github.com/coder/coder/coderd/autobuild/notify"
|
||||
@@ -43,12 +40,11 @@ func ssh() *cobra.Command {
|
||||
forwardAgent bool
|
||||
identityAgent string
|
||||
wsPollInterval time.Duration
|
||||
wireguard bool
|
||||
)
|
||||
cmd := &cobra.Command{
|
||||
Annotations: workspaceCommand,
|
||||
Use: "ssh <workspace>",
|
||||
Short: "SSH into a workspace",
|
||||
Short: "Start a shell into a workspace",
|
||||
Args: cobra.ArbitraryArgs,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
@@ -88,12 +84,7 @@ func ssh() *cobra.Command {
|
||||
return xerrors.Errorf("await agent: %w", err)
|
||||
}
|
||||
|
||||
var conn agent.Conn
|
||||
if !wireguard {
|
||||
conn, err = client.DialWorkspaceAgent(ctx, workspaceAgent.ID, nil)
|
||||
} else {
|
||||
conn, err = client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, workspaceAgent.ID)
|
||||
}
|
||||
conn, err := client.DialWorkspaceAgent(ctx, workspaceAgent.ID, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -195,6 +186,13 @@ func ssh() *cobra.Command {
|
||||
// shutdown of services.
|
||||
defer cancel()
|
||||
|
||||
if validOut {
|
||||
// Set initial window size.
|
||||
width, height, err := term.GetSize(int(stdoutFile.Fd()))
|
||||
if err == nil {
|
||||
_ = sshSession.WindowChange(height, width)
|
||||
}
|
||||
}
|
||||
err = sshSession.Wait()
|
||||
if err != nil {
|
||||
// If the connection drops unexpectedly, we get an ExitMissingError but no other
|
||||
@@ -214,9 +212,6 @@ func ssh() *cobra.Command {
|
||||
cliflag.BoolVarP(cmd.Flags(), &forwardAgent, "forward-agent", "A", "CODER_SSH_FORWARD_AGENT", false, "Specifies whether to forward the SSH agent specified in $SSH_AUTH_SOCK")
|
||||
cliflag.StringVarP(cmd.Flags(), &identityAgent, "identity-agent", "", "CODER_SSH_IDENTITY_AGENT", "", "Specifies which identity agent to use (overrides $SSH_AUTH_SOCK), forward agent must also be enabled")
|
||||
cliflag.DurationVarP(cmd.Flags(), &wsPollInterval, "workspace-poll-interval", "", "CODER_WORKSPACE_POLL_INTERVAL", workspacePollInterval, "Specifies how often to poll for workspace automated shutdown.")
|
||||
cliflag.BoolVarP(cmd.Flags(), &wireguard, "wireguard", "", "CODER_SSH_WIREGUARD", false, "Whether to use Wireguard for SSH tunneling.")
|
||||
_ = cmd.Flags().MarkHidden("wireguard")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
@@ -264,10 +259,7 @@ func getWorkspaceAndAgent(ctx context.Context, cmd *cobra.Command, client *coder
|
||||
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("workspace %q is being deleted", workspace.Name)
|
||||
}
|
||||
|
||||
resources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID)
|
||||
if err != nil {
|
||||
return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.Errorf("fetch workspace resources: %w", err)
|
||||
}
|
||||
resources := workspace.LatestBuild.Resources
|
||||
|
||||
agents := make([]codersdk.WorkspaceAgent, 0)
|
||||
for _, resource := range resources {
|
||||
|
||||
@@ -90,7 +90,6 @@ func TestSSH(t *testing.T) {
|
||||
agentClient.SessionToken = agentToken
|
||||
agentCloser := agent.New(agent.Options{
|
||||
FetchMetadata: agentClient.WorkspaceAgentMetadata,
|
||||
WebRTCDialer: agentClient.ListenWorkspaceAgent,
|
||||
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
|
||||
Logger: slogtest.Make(t, nil).Named("agent"),
|
||||
})
|
||||
@@ -112,7 +111,6 @@ func TestSSH(t *testing.T) {
|
||||
agentClient.SessionToken = agentToken
|
||||
agentCloser := agent.New(agent.Options{
|
||||
FetchMetadata: agentClient.WorkspaceAgentMetadata,
|
||||
WebRTCDialer: agentClient.ListenWorkspaceAgent,
|
||||
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
|
||||
Logger: slogtest.Make(t, nil).Named("agent"),
|
||||
})
|
||||
@@ -181,7 +179,6 @@ func TestSSH(t *testing.T) {
|
||||
agentClient.SessionToken = agentToken
|
||||
agentCloser := agent.New(agent.Options{
|
||||
FetchMetadata: agentClient.WorkspaceAgentMetadata,
|
||||
WebRTCDialer: agentClient.ListenWorkspaceAgent,
|
||||
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
|
||||
Logger: slogtest.Make(t, nil).Named("agent"),
|
||||
})
|
||||
|
||||
+1
-1
@@ -14,7 +14,7 @@ func start() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Annotations: workspaceCommand,
|
||||
Use: "start <workspace>",
|
||||
Short: "Build a workspace with the start state",
|
||||
Short: "Start a workspace",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
client, err := CreateClient(cmd)
|
||||
|
||||
@@ -17,6 +17,9 @@ func state() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "state",
|
||||
Short: "Manually manage Terraform state to fix broken workspaces",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return cmd.Help()
|
||||
},
|
||||
}
|
||||
cmd.AddCommand(statePull(), statePush())
|
||||
return cmd
|
||||
|
||||
+1
-1
@@ -14,7 +14,7 @@ func stop() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Annotations: workspaceCommand,
|
||||
Use: "stop <workspace>",
|
||||
Short: "Build a workspace with the stop state",
|
||||
Short: "Stop a workspace",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
_, err := cliui.Prompt(cmd, cliui.PromptOptions{
|
||||
|
||||
@@ -2,6 +2,7 @@ package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -9,6 +10,7 @@ import (
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/briandowns/spinner"
|
||||
"github.com/google/uuid"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
@@ -90,7 +92,7 @@ func templateCreate() *cobra.Command {
|
||||
Client: client,
|
||||
Organization: organization,
|
||||
Provisioner: database.ProvisionerType(provisioner),
|
||||
FileHash: resp.Hash,
|
||||
FileID: resp.ID,
|
||||
ParameterFile: parameterFile,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -143,10 +145,11 @@ func templateCreate() *cobra.Command {
|
||||
}
|
||||
|
||||
type createValidTemplateVersionArgs struct {
|
||||
Name string
|
||||
Client *codersdk.Client
|
||||
Organization codersdk.Organization
|
||||
Provisioner database.ProvisionerType
|
||||
FileHash string
|
||||
FileID uuid.UUID
|
||||
ParameterFile string
|
||||
// Template is only required if updating a template's active version.
|
||||
Template *codersdk.Template
|
||||
@@ -161,8 +164,9 @@ func createValidTemplateVersion(cmd *cobra.Command, args createValidTemplateVers
|
||||
client := args.Client
|
||||
|
||||
req := codersdk.CreateTemplateVersionRequest{
|
||||
Name: args.Name,
|
||||
StorageMethod: codersdk.ProvisionerStorageMethodFile,
|
||||
StorageSource: args.FileHash,
|
||||
FileID: args.FileID,
|
||||
Provisioner: codersdk.ProvisionerType(args.Provisioner),
|
||||
ParameterValues: parameters,
|
||||
}
|
||||
@@ -182,7 +186,7 @@ func createValidTemplateVersion(cmd *cobra.Command, args createValidTemplateVers
|
||||
Cancel: func() error {
|
||||
return client.CancelTemplateVersion(cmd.Context(), version.ID)
|
||||
},
|
||||
Logs: func() (<-chan codersdk.ProvisionerJobLog, error) {
|
||||
Logs: func() (<-chan codersdk.ProvisionerJobLog, io.Closer, error) {
|
||||
return client.TemplateVersionLogsAfter(cmd.Context(), version.ID, before)
|
||||
},
|
||||
})
|
||||
|
||||
@@ -49,7 +49,7 @@ func TestTemplateList(t *testing.T) {
|
||||
})
|
||||
t.Run("NoTemplates", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
client := coderdtest.New(t, &coderdtest.Options{})
|
||||
coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
cmd, root := clitest.New(t, "templates", "list")
|
||||
@@ -66,6 +66,8 @@ func TestTemplateList(t *testing.T) {
|
||||
|
||||
require.NoError(t, <-errC)
|
||||
|
||||
pty.ExpectMatch("No templates found in testuser! Create one:")
|
||||
pty.ExpectMatch("No templates found in")
|
||||
pty.ExpectMatch(coderdtest.FirstUserParams.Username)
|
||||
pty.ExpectMatch("Create one:")
|
||||
})
|
||||
}
|
||||
|
||||
+1
-1
@@ -66,7 +66,7 @@ func templatePull() *cobra.Command {
|
||||
latest := versions[0]
|
||||
|
||||
// Download the tar archive.
|
||||
raw, ctype, err := client.Download(ctx, latest.Job.StorageSource)
|
||||
raw, ctype, err := client.Download(ctx, latest.Job.FileID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("download template: %w", err)
|
||||
}
|
||||
|
||||
+4
-1
@@ -19,6 +19,7 @@ import (
|
||||
func templatePush() *cobra.Command {
|
||||
var (
|
||||
directory string
|
||||
versionName string
|
||||
provisioner string
|
||||
parameterFile string
|
||||
alwaysPrompt bool
|
||||
@@ -75,10 +76,11 @@ func templatePush() *cobra.Command {
|
||||
spin.Stop()
|
||||
|
||||
job, _, err := createValidTemplateVersion(cmd, createValidTemplateVersionArgs{
|
||||
Name: versionName,
|
||||
Client: client,
|
||||
Organization: organization,
|
||||
Provisioner: database.ProvisionerType(provisioner),
|
||||
FileHash: resp.Hash,
|
||||
FileID: resp.ID,
|
||||
ParameterFile: parameterFile,
|
||||
Template: &template,
|
||||
ReuseParameters: !alwaysPrompt,
|
||||
@@ -107,6 +109,7 @@ func templatePush() *cobra.Command {
|
||||
cmd.Flags().StringVarP(&directory, "directory", "d", currentDirectory, "Specify the directory to create from")
|
||||
cmd.Flags().StringVarP(&provisioner, "test.provisioner", "", "terraform", "Customize the provisioner backend")
|
||||
cmd.Flags().StringVarP(¶meterFile, "parameter-file", "", "", "Specify a file path with parameter values.")
|
||||
cmd.Flags().StringVarP(&versionName, "name", "", "", "Specify a name for the new template version. It will be automatically generated if not provided.")
|
||||
cmd.Flags().BoolVar(&alwaysPrompt, "always-prompt", false, "Always prompt all parameters. Does not pull parameter values from active template version")
|
||||
cliui.AllowSkipPrompt(cmd)
|
||||
// This is for testing!
|
||||
|
||||
@@ -122,7 +122,7 @@ func TestTemplatePush(t *testing.T) {
|
||||
Parse: echo.ParseComplete,
|
||||
Provision: echo.ProvisionComplete,
|
||||
})
|
||||
cmd, root := clitest.New(t, "templates", "push", template.Name, "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho))
|
||||
cmd, root := clitest.New(t, "templates", "push", template.Name, "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho), "--name", "example")
|
||||
clitest.SetupConfig(t, client, root)
|
||||
pty := ptytest.New(t)
|
||||
cmd.SetIn(pty.Input())
|
||||
@@ -153,6 +153,7 @@ func TestTemplatePush(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, templateVersions, 2)
|
||||
assert.NotEqual(t, template.ActiveVersionID, templateVersions[1].ID)
|
||||
require.Equal(t, "example", templateVersions[1].Name)
|
||||
})
|
||||
|
||||
t.Run("UseWorkingDir", func(t *testing.T) {
|
||||
|
||||
+6
-8
@@ -1,7 +1,6 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -14,7 +13,8 @@ import (
|
||||
func templates() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "templates",
|
||||
Short: "Create, manage, and deploy templates",
|
||||
Short: "Manage templates",
|
||||
Long: "Templates are written in standard Terraform and describe the infrastructure for workspaces",
|
||||
Aliases: []string{"template"},
|
||||
Example: formatExamples(
|
||||
example{
|
||||
@@ -30,6 +30,9 @@ func templates() *cobra.Command {
|
||||
Command: "coder templates push my-template",
|
||||
},
|
||||
),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return cmd.Help()
|
||||
},
|
||||
}
|
||||
cmd.AddCommand(
|
||||
templateCreate(),
|
||||
@@ -64,11 +67,6 @@ type templateTableRow struct {
|
||||
func displayTemplates(filterColumns []string, templates ...codersdk.Template) (string, error) {
|
||||
rows := make([]templateTableRow, len(templates))
|
||||
for i, template := range templates {
|
||||
suffix := ""
|
||||
if template.WorkspaceOwnerCount != 1 {
|
||||
suffix = "s"
|
||||
}
|
||||
|
||||
rows[i] = templateTableRow{
|
||||
Name: template.Name,
|
||||
CreatedAt: template.CreatedAt.Format("January 2, 2006"),
|
||||
@@ -76,7 +74,7 @@ func displayTemplates(filterColumns []string, templates ...codersdk.Template) (s
|
||||
OrganizationID: template.OrganizationID,
|
||||
Provisioner: template.Provisioner,
|
||||
ActiveVersionID: template.ActiveVersionID,
|
||||
UsedBy: cliui.Styles.Fuchsia.Render(fmt.Sprintf("%d developer%s", template.WorkspaceOwnerCount, suffix)),
|
||||
UsedBy: cliui.Styles.Fuchsia.Render(formatActiveDevelopers(template.ActiveUserCount)),
|
||||
MaxTTL: (time.Duration(template.MaxTTLMillis) * time.Millisecond),
|
||||
MinAutostartInterval: (time.Duration(template.MinAutostartIntervalMillis) * time.Millisecond),
|
||||
}
|
||||
|
||||
@@ -24,6 +24,9 @@ func templateVersions() *cobra.Command {
|
||||
Command: "coder templates versions list my-template",
|
||||
},
|
||||
),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return cmd.Help()
|
||||
},
|
||||
}
|
||||
cmd.AddCommand(
|
||||
templateVersionsList(),
|
||||
|
||||
+158
@@ -0,0 +1,158 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/cli/cliui"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
||||
func tokens() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "tokens",
|
||||
Short: "Manage personal access tokens",
|
||||
Long: "Tokens are used to authenticate automated clients to Coder.",
|
||||
Aliases: []string{"token"},
|
||||
Example: formatExamples(
|
||||
example{
|
||||
Description: "Create a token for automation",
|
||||
Command: "coder tokens create",
|
||||
},
|
||||
example{
|
||||
Description: "List your tokens",
|
||||
Command: "coder tokens ls",
|
||||
},
|
||||
example{
|
||||
Description: "Remove a token by ID",
|
||||
Command: "coder tokens rm WuoWs4ZsMX",
|
||||
},
|
||||
),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return cmd.Help()
|
||||
},
|
||||
}
|
||||
cmd.AddCommand(
|
||||
createToken(),
|
||||
listTokens(),
|
||||
removeToken(),
|
||||
)
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func createToken() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "create",
|
||||
Short: "Create a tokens",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
client, err := CreateClient(cmd)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create codersdk client: %w", err)
|
||||
}
|
||||
|
||||
res, err := client.CreateToken(cmd.Context(), codersdk.Me, codersdk.CreateTokenRequest{})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create tokens: %w", err)
|
||||
}
|
||||
|
||||
cmd.Println(cliui.Styles.Wrap.Render(
|
||||
"Here is your token. 🪄",
|
||||
))
|
||||
cmd.Println()
|
||||
cmd.Println(cliui.Styles.Code.Render(strings.TrimSpace(res.Key)))
|
||||
cmd.Println()
|
||||
cmd.Println(cliui.Styles.Wrap.Render(
|
||||
fmt.Sprintf("You can use this token by setting the --%s CLI flag, the %s environment variable, or the %q HTTP header.", varToken, envSessionToken, codersdk.SessionCustomHeader),
|
||||
))
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
type tokenRow struct {
|
||||
ID string `table:"ID"`
|
||||
LastUsed time.Time `table:"Last Used"`
|
||||
ExpiresAt time.Time `table:"Expires At"`
|
||||
CreatedAt time.Time `table:"Created At"`
|
||||
}
|
||||
|
||||
func listTokens() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "list",
|
||||
Aliases: []string{"ls"},
|
||||
Short: "List tokens",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
client, err := CreateClient(cmd)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create codersdk client: %w", err)
|
||||
}
|
||||
|
||||
keys, err := client.GetTokens(cmd.Context(), codersdk.Me)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create tokens: %w", err)
|
||||
}
|
||||
|
||||
if len(keys) == 0 {
|
||||
cmd.Println(cliui.Styles.Wrap.Render(
|
||||
"No tokens found.",
|
||||
))
|
||||
}
|
||||
|
||||
var rows []tokenRow
|
||||
for _, key := range keys {
|
||||
rows = append(rows, tokenRow{
|
||||
ID: key.ID,
|
||||
LastUsed: key.LastUsed,
|
||||
ExpiresAt: key.ExpiresAt,
|
||||
CreatedAt: key.CreatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
out, err := cliui.DisplayTable(rows, "", nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = fmt.Fprintln(cmd.OutOrStdout(), out)
|
||||
return err
|
||||
},
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func removeToken() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "remove [id]",
|
||||
Aliases: []string{"rm"},
|
||||
Short: "Delete a token",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
client, err := CreateClient(cmd)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create codersdk client: %w", err)
|
||||
}
|
||||
|
||||
err = client.DeleteAPIKey(cmd.Context(), codersdk.Me, args[0])
|
||||
if err != nil {
|
||||
return xerrors.Errorf("delete api key: %w", err)
|
||||
}
|
||||
|
||||
cmd.Println(cliui.Styles.Wrap.Render(
|
||||
"Token has been deleted.",
|
||||
))
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/cli/clitest"
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
)
|
||||
|
||||
func TestTokens(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// helpful empty response
|
||||
cmd, root := clitest.New(t, "tokens", "ls")
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf := new(bytes.Buffer)
|
||||
cmd.SetOut(buf)
|
||||
err := cmd.Execute()
|
||||
require.NoError(t, err)
|
||||
res := buf.String()
|
||||
require.Contains(t, res, "tokens found")
|
||||
|
||||
cmd, root = clitest.New(t, "tokens", "create")
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf = new(bytes.Buffer)
|
||||
cmd.SetOut(buf)
|
||||
err = cmd.Execute()
|
||||
require.NoError(t, err)
|
||||
res = buf.String()
|
||||
require.NotEmpty(t, res)
|
||||
// find API key in format "XXXXXXXXXX-XXXXXXXXXXXXXXXXXXXXXX"
|
||||
r := regexp.MustCompile("[a-zA-Z0-9]{10}-[a-zA-Z0-9]{22}")
|
||||
require.Regexp(t, r, res)
|
||||
key := r.FindString(res)
|
||||
id := key[:10]
|
||||
|
||||
cmd, root = clitest.New(t, "tokens", "ls")
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf = new(bytes.Buffer)
|
||||
cmd.SetOut(buf)
|
||||
err = cmd.Execute()
|
||||
require.NoError(t, err)
|
||||
res = buf.String()
|
||||
require.NotEmpty(t, res)
|
||||
require.Contains(t, res, "ID")
|
||||
require.Contains(t, res, "EXPIRES AT")
|
||||
require.Contains(t, res, "CREATED AT")
|
||||
require.Contains(t, res, "LAST USED")
|
||||
require.Contains(t, res, id)
|
||||
|
||||
cmd, root = clitest.New(t, "tokens", "rm", id)
|
||||
clitest.SetupConfig(t, client, root)
|
||||
buf = new(bytes.Buffer)
|
||||
cmd.SetOut(buf)
|
||||
err = cmd.Execute()
|
||||
require.NoError(t, err)
|
||||
res = buf.String()
|
||||
require.NotEmpty(t, res)
|
||||
require.Contains(t, res, "deleted")
|
||||
}
|
||||
+3
-2
@@ -20,7 +20,7 @@ func update() *cobra.Command {
|
||||
Annotations: workspaceCommand,
|
||||
Use: "update <workspace>",
|
||||
Args: cobra.ExactArgs(1),
|
||||
Short: "Update a workspace to the latest template version",
|
||||
Short: "Update a workspace",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
client, err := CreateClient(cmd)
|
||||
if err != nil {
|
||||
@@ -66,10 +66,11 @@ func update() *cobra.Command {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
logs, err := client.WorkspaceBuildLogsAfter(cmd.Context(), build.ID, before)
|
||||
logs, closer, err := client.WorkspaceBuildLogsAfter(cmd.Context(), build.ID, before)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer closer.Close()
|
||||
for {
|
||||
log, ok := <-logs
|
||||
if !ok {
|
||||
|
||||
+4
-1
@@ -8,9 +8,12 @@ import (
|
||||
|
||||
func users() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Short: "Create, remove, and list users",
|
||||
Short: "Manage users",
|
||||
Use: "users",
|
||||
Aliases: []string{"user"},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return cmd.Help()
|
||||
},
|
||||
}
|
||||
cmd.AddCommand(
|
||||
userCreate(),
|
||||
|
||||
+17
@@ -2,6 +2,7 @@ package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -175,3 +176,19 @@ func parseTime(s string) (time.Time, error) {
|
||||
}
|
||||
return time.Time{}, errInvalidTimeFormat
|
||||
}
|
||||
|
||||
func formatActiveDevelopers(n int) string {
|
||||
developerText := "developer"
|
||||
if n != 1 {
|
||||
developerText = "developers"
|
||||
}
|
||||
|
||||
var nStr string
|
||||
if n < 0 {
|
||||
nStr = "-"
|
||||
} else {
|
||||
nStr = strconv.Itoa(n)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s active %s", nStr, developerText)
|
||||
}
|
||||
|
||||
+3
-2
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -100,7 +101,7 @@ func main() {
|
||||
Fetch: func() (codersdk.ProvisionerJob, error) {
|
||||
return job, nil
|
||||
},
|
||||
Logs: func() (<-chan codersdk.ProvisionerJobLog, error) {
|
||||
Logs: func() (<-chan codersdk.ProvisionerJobLog, io.Closer, error) {
|
||||
logs := make(chan codersdk.ProvisionerJobLog)
|
||||
go func() {
|
||||
defer close(logs)
|
||||
@@ -143,7 +144,7 @@ func main() {
|
||||
}
|
||||
}
|
||||
}()
|
||||
return logs, nil
|
||||
return logs, io.NopCloser(strings.NewReader("")), nil
|
||||
},
|
||||
Cancel: func() error {
|
||||
job.Status = codersdk.ProvisionerJobCanceling
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
# Run "coder server --help" for flag information.
|
||||
# Coder must be reachable from an external URL for users and workspaces to connect.
|
||||
# e.g. https://coder.example.com
|
||||
CODER_ACCESS_URL=
|
||||
|
||||
CODER_ADDRESS=
|
||||
CODER_PG_CONNECTION_URL=
|
||||
CODER_TLS_CERT_FILE=
|
||||
CODER_TLS_ENABLE=
|
||||
CODER_TLS_KEY_FILE=
|
||||
# Generate a unique *.try.coder.app access URL
|
||||
CODER_TUNNEL=false
|
||||
|
||||
# Run "coder server --help" for flag information.
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
)
|
||||
|
||||
// activityBumpWorkspace automatically bumps the workspace's auto-off timer
|
||||
// if it is set to expire soon.
|
||||
func activityBumpWorkspace(log slog.Logger, db database.Store, workspace database.Workspace) {
|
||||
// We set a short timeout so if the app is under load, these
|
||||
// low priority operations fail first.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
|
||||
defer cancel()
|
||||
|
||||
err := db.InTx(func(s database.Store) error {
|
||||
build, err := s.GetLatestWorkspaceBuildByWorkspaceID(ctx, workspace.ID)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return xerrors.Errorf("get latest workspace build: %w", err)
|
||||
}
|
||||
|
||||
job, err := s.GetProvisionerJobByID(ctx, build.JobID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get provisioner job: %w", err)
|
||||
}
|
||||
|
||||
if build.Transition != database.WorkspaceTransitionStart || !job.CompletedAt.Valid {
|
||||
return nil
|
||||
}
|
||||
|
||||
if build.Deadline.IsZero() {
|
||||
// Workspace shutdown is manual
|
||||
return nil
|
||||
}
|
||||
|
||||
// We sent bumpThreshold slightly under bumpAmount to minimize DB writes.
|
||||
const (
|
||||
bumpAmount = time.Hour
|
||||
bumpThreshold = time.Hour - (time.Minute * 10)
|
||||
)
|
||||
|
||||
if !build.Deadline.Before(time.Now().Add(bumpThreshold)) {
|
||||
return nil
|
||||
}
|
||||
|
||||
newDeadline := database.Now().Add(bumpAmount)
|
||||
|
||||
if err := s.UpdateWorkspaceBuildByID(ctx, database.UpdateWorkspaceBuildByIDParams{
|
||||
ID: build.ID,
|
||||
UpdatedAt: database.Now(),
|
||||
ProvisionerState: build.ProvisionerState,
|
||||
Deadline: newDeadline,
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("update workspace build: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
log.Error(
|
||||
ctx, "bump failed",
|
||||
slog.Error(err),
|
||||
slog.F("workspace_id", workspace.ID),
|
||||
)
|
||||
} else {
|
||||
log.Debug(
|
||||
ctx, "bumped deadline from activity",
|
||||
slog.F("workspace_id", workspace.ID),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,110 @@
|
||||
package coderd_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
func TestWorkspaceActivityBump(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
setupActivityTest := func(t *testing.T) (client *codersdk.Client, workspace codersdk.Workspace, assertBumped func(want bool)) {
|
||||
var ttlMillis int64 = 60 * 1000
|
||||
|
||||
client = coderdtest.New(t, &coderdtest.Options{
|
||||
AppHostname: proxyTestSubdomainRaw,
|
||||
IncludeProvisionerDaemon: true,
|
||||
AgentStatsRefreshInterval: time.Millisecond * 100,
|
||||
MetricsCacheRefreshInterval: time.Millisecond * 100,
|
||||
})
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
workspace = createWorkspaceWithApps(t, client, user.OrganizationID, 1234, func(cwr *codersdk.CreateWorkspaceRequest) {
|
||||
cwr.TTLMillis = &ttlMillis
|
||||
})
|
||||
|
||||
// Sanity-check that deadline is near.
|
||||
workspace, err := client.Workspace(ctx, workspace.ID)
|
||||
require.NoError(t, err)
|
||||
require.WithinDuration(t,
|
||||
time.Now().Add(time.Duration(ttlMillis)*time.Millisecond),
|
||||
workspace.LatestBuild.Deadline.Time, testutil.WaitShort,
|
||||
)
|
||||
firstDeadline := workspace.LatestBuild.Deadline.Time
|
||||
|
||||
_ = coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
|
||||
|
||||
return client, workspace, func(want bool) {
|
||||
if !want {
|
||||
// It is difficult to test the absence of a call in a non-racey
|
||||
// way. In general, it is difficult for the API to generate
|
||||
// false positive activity since Agent networking event
|
||||
// is required. The Activity Bump behavior is also coupled with
|
||||
// Last Used, so it would be obvious to the user if we
|
||||
// are falsely recognizing activity.
|
||||
time.Sleep(testutil.IntervalMedium)
|
||||
workspace, err = client.Workspace(ctx, workspace.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, workspace.LatestBuild.Deadline.Time, firstDeadline)
|
||||
return
|
||||
}
|
||||
|
||||
// The Deadline bump occurs asynchronously.
|
||||
require.Eventuallyf(t,
|
||||
func() bool {
|
||||
workspace, err = client.Workspace(ctx, workspace.ID)
|
||||
require.NoError(t, err)
|
||||
return workspace.LatestBuild.Deadline.Time != firstDeadline
|
||||
},
|
||||
testutil.WaitShort, testutil.IntervalFast,
|
||||
"deadline %v never updated", firstDeadline,
|
||||
)
|
||||
|
||||
require.WithinDuration(t, database.Now().Add(time.Hour), workspace.LatestBuild.Deadline.Time, 3*time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("Dial", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, workspace, assertBumped := setupActivityTest(t)
|
||||
|
||||
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
|
||||
conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, &codersdk.DialWorkspaceAgentOptions{
|
||||
Logger: slogtest.Make(t, nil),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
sshConn, err := conn.SSHClient()
|
||||
require.NoError(t, err)
|
||||
_ = sshConn.Close()
|
||||
|
||||
assertBumped(true)
|
||||
})
|
||||
|
||||
t.Run("NoBump", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, workspace, assertBumped := setupActivityTest(t)
|
||||
|
||||
// Benign operations like retrieving workspace must not
|
||||
// bump the deadline.
|
||||
_, err := client.Workspace(ctx, workspace.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
assertBumped(false)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,287 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/tabbed/pqtype"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
"github.com/coder/coder/coderd/telemetry"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/cryptorand"
|
||||
)
|
||||
|
||||
// Creates a new token API key that effectively doesn't expire.
|
||||
func (api *API) postToken(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
user := httpmw.UserParam(r)
|
||||
|
||||
if !api.Authorize(r, rbac.ActionCreate, rbac.ResourceAPIKey.WithOwner(user.ID.String())) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
|
||||
var createToken codersdk.CreateTokenRequest
|
||||
if !httpapi.Read(ctx, rw, r, &createToken) {
|
||||
return
|
||||
}
|
||||
|
||||
scope := database.APIKeyScopeAll
|
||||
if scope != "" {
|
||||
scope = database.APIKeyScope(createToken.Scope)
|
||||
}
|
||||
|
||||
// tokens last 100 years
|
||||
lifeTime := time.Hour * 876000
|
||||
cookie, err := api.createAPIKey(ctx, createAPIKeyParams{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypeToken,
|
||||
ExpiresAt: database.Now().Add(lifeTime),
|
||||
Scope: scope,
|
||||
LifetimeSeconds: int64(lifeTime.Seconds()),
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to create API key.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusCreated, codersdk.GenerateAPIKeyResponse{Key: cookie.Value})
|
||||
}
|
||||
|
||||
// Creates a new session key, used for logging in via the CLI.
|
||||
// DEPRECATED: use postToken instead.
|
||||
func (api *API) postAPIKey(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
user := httpmw.UserParam(r)
|
||||
|
||||
if !api.Authorize(r, rbac.ActionCreate, rbac.ResourceAPIKey.WithOwner(user.ID.String())) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
|
||||
lifeTime := time.Hour * 24 * 7
|
||||
cookie, err := api.createAPIKey(ctx, createAPIKeyParams{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypePassword,
|
||||
RemoteAddr: r.RemoteAddr,
|
||||
// All api generated keys will last 1 week. Browser login tokens have
|
||||
// a shorter life.
|
||||
ExpiresAt: database.Now().Add(lifeTime),
|
||||
LifetimeSeconds: int64(lifeTime.Seconds()),
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to create API key.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// We intentionally do not set the cookie on the response here.
|
||||
// Setting the cookie will couple the browser sesion to the API
|
||||
// key we return here, meaning logging out of the website would
|
||||
// invalid your CLI key.
|
||||
httpapi.Write(ctx, rw, http.StatusCreated, codersdk.GenerateAPIKeyResponse{Key: cookie.Value})
|
||||
}
|
||||
|
||||
func (api *API) apiKey(rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
ctx = r.Context()
|
||||
user = httpmw.UserParam(r)
|
||||
)
|
||||
|
||||
if !api.Authorize(r, rbac.ActionRead, rbac.ResourceAPIKey.WithOwner(user.ID.String())) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
|
||||
keyID := chi.URLParam(r, "keyid")
|
||||
key, err := api.Database.GetAPIKeyByID(ctx, keyID)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching API key.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, convertAPIKey(key))
|
||||
}
|
||||
|
||||
func (api *API) tokens(rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
ctx = r.Context()
|
||||
user = httpmw.UserParam(r)
|
||||
)
|
||||
|
||||
if !api.Authorize(r, rbac.ActionRead, rbac.ResourceAPIKey.WithOwner(user.ID.String())) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
|
||||
keys, err := api.Database.GetAPIKeysByLoginType(ctx, database.LoginTypeToken)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, []codersdk.APIKey{})
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching API keys.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var apiKeys []codersdk.APIKey
|
||||
for _, key := range keys {
|
||||
apiKeys = append(apiKeys, convertAPIKey(key))
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, apiKeys)
|
||||
}
|
||||
|
||||
func (api *API) deleteAPIKey(rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
ctx = r.Context()
|
||||
user = httpmw.UserParam(r)
|
||||
)
|
||||
|
||||
if !api.Authorize(r, rbac.ActionDelete, rbac.ResourceAPIKey.WithOwner(user.ID.String())) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
|
||||
keyID := chi.URLParam(r, "keyid")
|
||||
err := api.Database.DeleteAPIKeyByID(ctx, keyID)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error deleting API key.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusNoContent, nil)
|
||||
}
|
||||
|
||||
// Generates a new ID and secret for an API key.
|
||||
func generateAPIKeyIDSecret() (id string, secret string, err error) {
|
||||
// Length of an API Key ID.
|
||||
id, err = cryptorand.String(10)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
// Length of an API Key secret.
|
||||
secret, err = cryptorand.String(22)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return id, secret, nil
|
||||
}
|
||||
|
||||
type createAPIKeyParams struct {
|
||||
UserID uuid.UUID
|
||||
RemoteAddr string
|
||||
LoginType database.LoginType
|
||||
|
||||
// Optional.
|
||||
ExpiresAt time.Time
|
||||
LifetimeSeconds int64
|
||||
Scope database.APIKeyScope
|
||||
}
|
||||
|
||||
func (api *API) createAPIKey(ctx context.Context, params createAPIKeyParams) (*http.Cookie, error) {
|
||||
keyID, keySecret, err := generateAPIKeyIDSecret()
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("generate API key: %w", err)
|
||||
}
|
||||
hashed := sha256.Sum256([]byte(keySecret))
|
||||
|
||||
// Default expires at to now+lifetime, or just 24hrs if not set
|
||||
if params.ExpiresAt.IsZero() {
|
||||
if params.LifetimeSeconds != 0 {
|
||||
params.ExpiresAt = database.Now().Add(time.Duration(params.LifetimeSeconds) * time.Second)
|
||||
} else {
|
||||
params.ExpiresAt = database.Now().Add(24 * time.Hour)
|
||||
}
|
||||
}
|
||||
|
||||
host, _, _ := net.SplitHostPort(params.RemoteAddr)
|
||||
ip := net.ParseIP(host)
|
||||
if ip == nil {
|
||||
ip = net.IPv4(0, 0, 0, 0)
|
||||
}
|
||||
bitlen := len(ip) * 8
|
||||
|
||||
scope := database.APIKeyScopeAll
|
||||
if params.Scope != "" {
|
||||
scope = params.Scope
|
||||
}
|
||||
switch scope {
|
||||
case database.APIKeyScopeAll, database.APIKeyScopeApplicationConnect:
|
||||
default:
|
||||
return nil, xerrors.Errorf("invalid API key scope: %q", scope)
|
||||
}
|
||||
|
||||
key, err := api.Database.InsertAPIKey(ctx, database.InsertAPIKeyParams{
|
||||
ID: keyID,
|
||||
UserID: params.UserID,
|
||||
LifetimeSeconds: params.LifetimeSeconds,
|
||||
IPAddress: pqtype.Inet{
|
||||
IPNet: net.IPNet{
|
||||
IP: ip,
|
||||
Mask: net.CIDRMask(bitlen, bitlen),
|
||||
},
|
||||
Valid: true,
|
||||
},
|
||||
// Make sure in UTC time for common time zone
|
||||
ExpiresAt: params.ExpiresAt.UTC(),
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
HashedSecret: hashed[:],
|
||||
LoginType: params.LoginType,
|
||||
Scope: scope,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("insert API key: %w", err)
|
||||
}
|
||||
|
||||
api.Telemetry.Report(&telemetry.Snapshot{
|
||||
APIKeys: []telemetry.APIKey{telemetry.ConvertAPIKey(key)},
|
||||
})
|
||||
|
||||
// This format is consumed by the APIKey middleware.
|
||||
sessionToken := fmt.Sprintf("%s-%s", keyID, keySecret)
|
||||
return &http.Cookie{
|
||||
Name: codersdk.SessionTokenKey,
|
||||
Value: sessionToken,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: api.SecureAuthCookie,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
package coderd_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
func TestTokens(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("CRUD", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
keys, err := client.GetTokens(ctx, codersdk.Me)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, keys)
|
||||
|
||||
res, err := client.CreateToken(ctx, codersdk.Me, codersdk.CreateTokenRequest{})
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, len(res.Key), 2)
|
||||
|
||||
keys, err = client.GetTokens(ctx, codersdk.Me)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, len(keys), 1)
|
||||
require.Contains(t, res.Key, keys[0].ID)
|
||||
// expires_at must be greater than 50 years
|
||||
require.Greater(t, keys[0].ExpiresAt, time.Now().Add(time.Hour*438300))
|
||||
require.Equal(t, codersdk.APIKeyScopeAll, keys[0].Scope)
|
||||
|
||||
// no update
|
||||
|
||||
err = client.DeleteAPIKey(ctx, codersdk.Me, keys[0].ID)
|
||||
require.NoError(t, err)
|
||||
keys, err = client.GetTokens(ctx, codersdk.Me)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, keys)
|
||||
})
|
||||
|
||||
t.Run("Scoped", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
res, err := client.CreateToken(ctx, codersdk.Me, codersdk.CreateTokenRequest{
|
||||
Scope: codersdk.APIKeyScopeApplicationConnect,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, len(res.Key), 2)
|
||||
|
||||
keys, err := client.GetTokens(ctx, codersdk.Me)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, len(keys), 1)
|
||||
require.Contains(t, res.Key, keys[0].ID)
|
||||
// expires_at must be greater than 50 years
|
||||
require.Greater(t, keys[0].ExpiresAt, time.Now().Add(time.Hour*438300))
|
||||
require.Equal(t, keys[0].Scope, codersdk.APIKeyScopeApplicationConnect)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAPIKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
res, err := client.CreateAPIKey(ctx, codersdk.Me)
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, len(res.Key), 2)
|
||||
}
|
||||
+148
-10
@@ -2,9 +2,12 @@ package coderd
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -18,27 +21,42 @@ import (
|
||||
)
|
||||
|
||||
func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if !api.Authorize(r, rbac.ActionRead, rbac.ResourceAuditLog) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
page, ok := parsePagination(rw, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
queryStr := r.URL.Query().Get("q")
|
||||
filter, errs := auditSearchQuery(queryStr)
|
||||
if len(errs) > 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid audit search query.",
|
||||
Validations: errs,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
dblogs, err := api.Database.GetAuditLogsOffset(ctx, database.GetAuditLogsOffsetParams{
|
||||
Offset: int32(page.Offset),
|
||||
Limit: int32(page.Limit),
|
||||
Offset: int32(page.Offset),
|
||||
Limit: int32(page.Limit),
|
||||
ResourceType: filter.ResourceType,
|
||||
ResourceID: filter.ResourceID,
|
||||
Action: filter.Action,
|
||||
Username: filter.Username,
|
||||
Email: filter.Email,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.InternalServerError(rw, err)
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(rw, http.StatusOK, codersdk.AuditLogResponse{
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogResponse{
|
||||
AuditLogs: convertAuditLogs(dblogs),
|
||||
})
|
||||
}
|
||||
@@ -50,13 +68,29 @@ func (api *API) auditLogCount(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
count, err := api.Database.GetAuditLogCount(ctx)
|
||||
queryStr := r.URL.Query().Get("q")
|
||||
filter, errs := auditSearchQuery(queryStr)
|
||||
if len(errs) > 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid audit search query.",
|
||||
Validations: errs,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
count, err := api.Database.GetAuditLogCount(ctx, database.GetAuditLogCountParams{
|
||||
ResourceType: filter.ResourceType,
|
||||
ResourceID: filter.ResourceID,
|
||||
Action: filter.Action,
|
||||
Username: filter.Username,
|
||||
Email: filter.Email,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.InternalServerError(rw, err)
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(rw, http.StatusOK, codersdk.AuditLogCountResponse{
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogCountResponse{
|
||||
Count: count,
|
||||
})
|
||||
}
|
||||
@@ -96,16 +130,30 @@ func (api *API) generateFakeAuditLog(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
var params codersdk.CreateTestAuditLogRequest
|
||||
if !httpapi.Read(ctx, rw, r, ¶ms) {
|
||||
return
|
||||
}
|
||||
if params.Action == "" {
|
||||
params.Action = codersdk.AuditActionWrite
|
||||
}
|
||||
if params.ResourceType == "" {
|
||||
params.ResourceType = codersdk.ResourceTypeUser
|
||||
}
|
||||
if params.ResourceID == uuid.Nil {
|
||||
params.ResourceID = uuid.New()
|
||||
}
|
||||
|
||||
_, err = api.Database.InsertAuditLog(ctx, database.InsertAuditLogParams{
|
||||
ID: uuid.New(),
|
||||
Time: time.Now(),
|
||||
UserID: user.ID,
|
||||
Ip: ipNet,
|
||||
UserAgent: r.UserAgent(),
|
||||
ResourceType: database.ResourceTypeUser,
|
||||
ResourceID: user.ID,
|
||||
ResourceType: database.ResourceType(params.ResourceType),
|
||||
ResourceID: params.ResourceID,
|
||||
ResourceTarget: user.Username,
|
||||
Action: database.AuditActionWrite,
|
||||
Action: database.AuditAction(params.Action),
|
||||
Diff: diff,
|
||||
StatusCode: http.StatusOK,
|
||||
AdditionalFields: []byte("{}"),
|
||||
@@ -167,7 +215,97 @@ func convertAuditLog(dblog database.GetAuditLogsOffsetRow) codersdk.AuditLog {
|
||||
Diff: diff,
|
||||
StatusCode: dblog.StatusCode,
|
||||
AdditionalFields: dblog.AdditionalFields,
|
||||
Description: "",
|
||||
Description: auditLogDescription(dblog),
|
||||
User: user,
|
||||
}
|
||||
}
|
||||
|
||||
func auditLogDescription(alog database.GetAuditLogsOffsetRow) string {
|
||||
str := fmt.Sprintf("{user} %s %s",
|
||||
codersdk.AuditAction(alog.Action).FriendlyString(),
|
||||
codersdk.ResourceType(alog.ResourceType).FriendlyString(),
|
||||
)
|
||||
|
||||
// We don't display the name for git ssh keys. It's fairly long and doesn't
|
||||
// make too much sense to display.
|
||||
if alog.ResourceType != database.ResourceTypeGitSshKey {
|
||||
str += " {target}"
|
||||
}
|
||||
|
||||
return str
|
||||
}
|
||||
|
||||
// auditSearchQuery takes a query string and returns the auditLog filter.
|
||||
// It also can return the list of validation errors to return to the api.
|
||||
func auditSearchQuery(query string) (database.GetAuditLogsOffsetParams, []codersdk.ValidationError) {
|
||||
searchParams := make(url.Values)
|
||||
if query == "" {
|
||||
// No filter
|
||||
return database.GetAuditLogsOffsetParams{}, nil
|
||||
}
|
||||
query = strings.ToLower(query)
|
||||
// Because we do this in 2 passes, we want to maintain quotes on the first
|
||||
// pass.Further splitting occurs on the second pass and quotes will be
|
||||
// dropped.
|
||||
elements := splitQueryParameterByDelimiter(query, ' ', true)
|
||||
for _, element := range elements {
|
||||
parts := splitQueryParameterByDelimiter(element, ':', false)
|
||||
switch len(parts) {
|
||||
case 1:
|
||||
// No key:value pair.
|
||||
searchParams.Set("resource_type", parts[0])
|
||||
case 2:
|
||||
searchParams.Set(parts[0], parts[1])
|
||||
default:
|
||||
return database.GetAuditLogsOffsetParams{}, []codersdk.ValidationError{
|
||||
{Field: "q", Detail: fmt.Sprintf("Query element %q can only contain 1 ':'", element)},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Using the query param parser here just returns consistent errors with
|
||||
// other parsing.
|
||||
parser := httpapi.NewQueryParamParser()
|
||||
filter := database.GetAuditLogsOffsetParams{
|
||||
ResourceType: resourceTypeFromString(parser.String(searchParams, "", "resource_type")),
|
||||
ResourceID: parser.UUID(searchParams, uuid.Nil, "resource_id"),
|
||||
Action: actionFromString(parser.String(searchParams, "", "action")),
|
||||
Username: parser.String(searchParams, "", "username"),
|
||||
Email: parser.String(searchParams, "", "email"),
|
||||
}
|
||||
|
||||
return filter, parser.Errors
|
||||
}
|
||||
|
||||
func resourceTypeFromString(resourceTypeString string) string {
|
||||
switch codersdk.ResourceType(resourceTypeString) {
|
||||
case codersdk.ResourceTypeOrganization:
|
||||
return resourceTypeString
|
||||
case codersdk.ResourceTypeTemplate:
|
||||
return resourceTypeString
|
||||
case codersdk.ResourceTypeTemplateVersion:
|
||||
return resourceTypeString
|
||||
case codersdk.ResourceTypeUser:
|
||||
return resourceTypeString
|
||||
case codersdk.ResourceTypeWorkspace:
|
||||
return resourceTypeString
|
||||
case codersdk.ResourceTypeGitSSHKey:
|
||||
return resourceTypeString
|
||||
case codersdk.ResourceTypeAPIKey:
|
||||
return resourceTypeString
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func actionFromString(actionString string) string {
|
||||
switch codersdk.AuditAction(actionString) {
|
||||
case codersdk.AuditActionCreate:
|
||||
return actionString
|
||||
case codersdk.AuditActionWrite:
|
||||
return actionString
|
||||
case codersdk.AuditActionDelete:
|
||||
return actionString
|
||||
default:
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
+20
-1
@@ -21,4 +21,23 @@ func (nop) Export(context.Context, database.AuditLog) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (nop) diff(any, any) Map { return Map{} }
|
||||
func (nop) diff(any, any) Map {
|
||||
return Map{}
|
||||
}
|
||||
|
||||
func NewMock() *MockAuditor {
|
||||
return &MockAuditor{}
|
||||
}
|
||||
|
||||
type MockAuditor struct {
|
||||
AuditLogs []database.AuditLog
|
||||
}
|
||||
|
||||
func (a *MockAuditor) Export(_ context.Context, alog database.AuditLog) error {
|
||||
a.AuditLogs = append(a.AuditLogs, alog)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*MockAuditor) diff(any, any) Map {
|
||||
return Map{}
|
||||
}
|
||||
|
||||
+106
-25
@@ -3,6 +3,7 @@ package audit
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
@@ -11,20 +12,16 @@ import (
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/coderd/tracing"
|
||||
)
|
||||
|
||||
type RequestParams struct {
|
||||
Audit Auditor
|
||||
Log slog.Logger
|
||||
|
||||
Request *http.Request
|
||||
ResourceID uuid.UUID
|
||||
ResourceTarget string
|
||||
Action database.AuditAction
|
||||
ResourceType database.ResourceType
|
||||
Actor uuid.UUID
|
||||
Request *http.Request
|
||||
Action database.AuditAction
|
||||
}
|
||||
|
||||
type Request[T Auditable] struct {
|
||||
@@ -34,13 +31,70 @@ type Request[T Auditable] struct {
|
||||
New T
|
||||
}
|
||||
|
||||
func ResourceTarget[T Auditable](tgt T) string {
|
||||
switch typed := any(tgt).(type) {
|
||||
case database.Organization:
|
||||
return typed.Name
|
||||
case database.Template:
|
||||
return typed.Name
|
||||
case database.TemplateVersion:
|
||||
return typed.Name
|
||||
case database.User:
|
||||
return typed.Username
|
||||
case database.Workspace:
|
||||
return typed.Name
|
||||
case database.GitSSHKey:
|
||||
return typed.PublicKey
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown resource %T", tgt))
|
||||
}
|
||||
}
|
||||
|
||||
func ResourceID[T Auditable](tgt T) uuid.UUID {
|
||||
switch typed := any(tgt).(type) {
|
||||
case database.Organization:
|
||||
return typed.ID
|
||||
case database.Template:
|
||||
return typed.ID
|
||||
case database.TemplateVersion:
|
||||
return typed.ID
|
||||
case database.User:
|
||||
return typed.ID
|
||||
case database.Workspace:
|
||||
return typed.ID
|
||||
case database.GitSSHKey:
|
||||
return typed.UserID
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown resource %T", tgt))
|
||||
}
|
||||
}
|
||||
|
||||
func ResourceType[T Auditable](tgt T) database.ResourceType {
|
||||
switch any(tgt).(type) {
|
||||
case database.Organization:
|
||||
return database.ResourceTypeOrganization
|
||||
case database.Template:
|
||||
return database.ResourceTypeTemplate
|
||||
case database.TemplateVersion:
|
||||
return database.ResourceTypeTemplateVersion
|
||||
case database.User:
|
||||
return database.ResourceTypeUser
|
||||
case database.Workspace:
|
||||
return database.ResourceTypeWorkspace
|
||||
case database.GitSSHKey:
|
||||
return database.ResourceTypeGitSshKey
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown resource %T", tgt))
|
||||
}
|
||||
}
|
||||
|
||||
// InitRequest initializes an audit log for a request. It returns a function
|
||||
// that should be deferred, causing the audit log to be committed when the
|
||||
// handler returns.
|
||||
func InitRequest[T Auditable](w http.ResponseWriter, p *RequestParams) (*Request[T], func()) {
|
||||
sw, ok := w.(*httpapi.StatusWriter)
|
||||
sw, ok := w.(*tracing.StatusWriter)
|
||||
if !ok {
|
||||
panic("dev error: http.ResponseWriter is not *httpapi.StatusWriter")
|
||||
panic("dev error: http.ResponseWriter is not *tracing.StatusWriter")
|
||||
}
|
||||
|
||||
req := &Request[T]{
|
||||
@@ -49,36 +103,63 @@ func InitRequest[T Auditable](w http.ResponseWriter, p *RequestParams) (*Request
|
||||
|
||||
return req, func() {
|
||||
ctx := context.Background()
|
||||
logCtx := p.Request.Context()
|
||||
|
||||
diff := Diff(p.Audit, req.Old, req.New)
|
||||
diffRaw, _ := json.Marshal(diff)
|
||||
// If no resources were provided, there's nothing we can audit.
|
||||
if ResourceID(req.Old) == uuid.Nil && ResourceID(req.New) == uuid.Nil {
|
||||
return
|
||||
}
|
||||
|
||||
var diffRaw = []byte("{}")
|
||||
// Only generate diffs if the request succeeded.
|
||||
if sw.Status < 400 {
|
||||
diff := Diff(p.Audit, req.Old, req.New)
|
||||
|
||||
var err error
|
||||
diffRaw, err = json.Marshal(diff)
|
||||
if err != nil {
|
||||
p.Log.Warn(logCtx, "marshal diff", slog.Error(err))
|
||||
diffRaw = []byte("{}")
|
||||
}
|
||||
}
|
||||
|
||||
ip, err := parseIP(p.Request.RemoteAddr)
|
||||
if err != nil {
|
||||
p.Log.Warn(ctx, "parse ip", slog.Error(err))
|
||||
p.Log.Warn(logCtx, "parse ip", slog.Error(err))
|
||||
}
|
||||
|
||||
err = p.Audit.Export(ctx, database.AuditLog{
|
||||
ID: uuid.New(),
|
||||
Time: database.Now(),
|
||||
UserID: p.Actor,
|
||||
Ip: ip,
|
||||
UserAgent: p.Request.UserAgent(),
|
||||
ResourceType: p.ResourceType,
|
||||
ResourceID: p.ResourceID,
|
||||
ResourceTarget: p.ResourceTarget,
|
||||
Action: p.Action,
|
||||
Diff: diffRaw,
|
||||
StatusCode: int32(sw.Status),
|
||||
RequestID: httpmw.RequestID(p.Request),
|
||||
ID: uuid.New(),
|
||||
Time: database.Now(),
|
||||
UserID: httpmw.APIKey(p.Request).UserID,
|
||||
Ip: ip,
|
||||
UserAgent: p.Request.UserAgent(),
|
||||
ResourceType: either(req.Old, req.New, ResourceType[T]),
|
||||
ResourceID: either(req.Old, req.New, ResourceID[T]),
|
||||
ResourceTarget: either(req.Old, req.New, ResourceTarget[T]),
|
||||
Action: p.Action,
|
||||
Diff: diffRaw,
|
||||
StatusCode: int32(sw.Status),
|
||||
RequestID: httpmw.RequestID(p.Request),
|
||||
AdditionalFields: json.RawMessage("{}"),
|
||||
})
|
||||
if err != nil {
|
||||
p.Log.Error(ctx, "export audit log", slog.Error(err))
|
||||
p.Log.Error(logCtx, "export audit log", slog.Error(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func either[T Auditable, R any](old, new T, fn func(T) R) R {
|
||||
if ResourceID(new) != uuid.Nil {
|
||||
return fn(new)
|
||||
} else if ResourceID(old) != uuid.Nil {
|
||||
return fn(old)
|
||||
} else {
|
||||
panic("both old and new are nil")
|
||||
}
|
||||
}
|
||||
|
||||
func parseIP(ipStr string) (pqtype.Inet, error) {
|
||||
var err error
|
||||
|
||||
|
||||
+126
-3
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
@@ -20,16 +21,138 @@ func TestAuditLogs(t *testing.T) {
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
err := client.CreateTestAuditLog(ctx)
|
||||
err := client.CreateTestAuditLog(ctx, codersdk.CreateTestAuditLogRequest{})
|
||||
require.NoError(t, err)
|
||||
|
||||
count, err := client.AuditLogCount(ctx)
|
||||
count, err := client.AuditLogCount(ctx, codersdk.AuditLogCountRequest{})
|
||||
require.NoError(t, err)
|
||||
|
||||
alogs, err := client.AuditLogs(ctx, codersdk.Pagination{Limit: 1})
|
||||
alogs, err := client.AuditLogs(ctx, codersdk.AuditLogsRequest{
|
||||
Pagination: codersdk.Pagination{
|
||||
Limit: 1,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, int64(1), count.Count)
|
||||
require.Len(t, alogs.AuditLogs, 1)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuditLogsFilter(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Filter", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
userResourceID := uuid.New()
|
||||
|
||||
// Create two logs with "Create"
|
||||
err := client.CreateTestAuditLog(ctx, codersdk.CreateTestAuditLogRequest{
|
||||
Action: codersdk.AuditActionCreate,
|
||||
ResourceType: codersdk.ResourceTypeTemplate,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
err = client.CreateTestAuditLog(ctx, codersdk.CreateTestAuditLogRequest{
|
||||
Action: codersdk.AuditActionCreate,
|
||||
ResourceType: codersdk.ResourceTypeUser,
|
||||
ResourceID: userResourceID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create one log with "Delete"
|
||||
err = client.CreateTestAuditLog(ctx, codersdk.CreateTestAuditLogRequest{
|
||||
Action: codersdk.AuditActionDelete,
|
||||
ResourceType: codersdk.ResourceTypeUser,
|
||||
ResourceID: userResourceID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test cases
|
||||
testCases := []struct {
|
||||
Name string
|
||||
SearchQuery string
|
||||
ExpectedResult int
|
||||
}{
|
||||
{
|
||||
Name: "FilterByCreateAction",
|
||||
SearchQuery: "action:create",
|
||||
ExpectedResult: 2,
|
||||
},
|
||||
{
|
||||
Name: "FilterByDeleteAction",
|
||||
SearchQuery: "action:delete",
|
||||
ExpectedResult: 1,
|
||||
},
|
||||
{
|
||||
Name: "FilterByUserResourceType",
|
||||
SearchQuery: "resource_type:user",
|
||||
ExpectedResult: 2,
|
||||
},
|
||||
{
|
||||
Name: "FilterByTemplateResourceType",
|
||||
SearchQuery: "resource_type:template",
|
||||
ExpectedResult: 1,
|
||||
},
|
||||
{
|
||||
Name: "FilterByEmail",
|
||||
SearchQuery: "email:" + coderdtest.FirstUserParams.Email,
|
||||
ExpectedResult: 3,
|
||||
},
|
||||
{
|
||||
Name: "FilterByUsername",
|
||||
SearchQuery: "username:" + coderdtest.FirstUserParams.Username,
|
||||
ExpectedResult: 3,
|
||||
},
|
||||
{
|
||||
Name: "FilterByResourceID",
|
||||
SearchQuery: "resource_id:" + userResourceID.String(),
|
||||
ExpectedResult: 2,
|
||||
},
|
||||
{
|
||||
Name: "FilterInvalidSingleValue",
|
||||
SearchQuery: "invalid",
|
||||
ExpectedResult: 3,
|
||||
},
|
||||
{
|
||||
Name: "FilterWithInvalidResourceType",
|
||||
SearchQuery: "resource_type:invalid",
|
||||
ExpectedResult: 3,
|
||||
},
|
||||
{
|
||||
Name: "FilterWithInvalidAction",
|
||||
SearchQuery: "action:invalid",
|
||||
ExpectedResult: 3,
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
testCase := testCase
|
||||
// Test filtering
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
auditLogs, err := client.AuditLogs(ctx, codersdk.AuditLogsRequest{
|
||||
SearchQuery: testCase.SearchQuery,
|
||||
Pagination: codersdk.Pagination{
|
||||
Limit: 25,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err, "fetch audit logs")
|
||||
require.Len(t, auditLogs.AuditLogs, testCase.ExpectedResult, "expected audit logs returned")
|
||||
})
|
||||
|
||||
// Test count filtering
|
||||
t.Run("GetCount"+testCase.Name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
response, err := client.AuditLogCount(ctx, codersdk.AuditLogCountRequest{
|
||||
SearchQuery: testCase.SearchQuery,
|
||||
})
|
||||
require.NoError(t, err, "fetch audit logs count")
|
||||
require.Equal(t, int(response.Count), testCase.ExpectedResult, "expected audit logs count returned")
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
+144
-5
@@ -1,24 +1,33 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
||||
// AuthorizeFilter takes a list of objects and returns the filtered list of
|
||||
// objects that the user is authorized to perform the given action on.
|
||||
// This is faster than calling Authorize() on each object.
|
||||
func AuthorizeFilter[O rbac.Objecter](h *HTTPAuthorizer, r *http.Request, action rbac.Action, objects []O) ([]O, error) {
|
||||
roles := httpmw.AuthorizationUserRoles(r)
|
||||
objects, err := rbac.Filter(r.Context(), h.Authorizer, roles.ID.String(), roles.Roles, action, objects)
|
||||
roles := httpmw.UserAuthorization(r)
|
||||
objects, err := rbac.Filter(r.Context(), h.Authorizer, roles.ID.String(), roles.Roles, roles.Scope.ToRBAC(), roles.Groups, action, objects)
|
||||
if err != nil {
|
||||
// Log the error as Filter should not be erroring.
|
||||
h.Logger.Error(r.Context(), "filter failed",
|
||||
slog.Error(err),
|
||||
slog.F("user_id", roles.ID),
|
||||
slog.F("username", roles.Username),
|
||||
slog.F("scope", roles.Scope),
|
||||
slog.F("route", r.URL.Path),
|
||||
slog.F("action", action),
|
||||
)
|
||||
@@ -42,7 +51,7 @@ type HTTPAuthorizer struct {
|
||||
// return
|
||||
// }
|
||||
func (api *API) Authorize(r *http.Request, action rbac.Action, object rbac.Objecter) bool {
|
||||
return api.httpAuth.Authorize(r, action, object)
|
||||
return api.HTTPAuth.Authorize(r, action, object)
|
||||
}
|
||||
|
||||
// Authorize will return false if the user is not authorized to do the action.
|
||||
@@ -55,8 +64,8 @@ func (api *API) Authorize(r *http.Request, action rbac.Action, object rbac.Objec
|
||||
// return
|
||||
// }
|
||||
func (h *HTTPAuthorizer) Authorize(r *http.Request, action rbac.Action, object rbac.Objecter) bool {
|
||||
roles := httpmw.AuthorizationUserRoles(r)
|
||||
err := h.Authorizer.ByRoleName(r.Context(), roles.ID.String(), roles.Roles, action, object.RBACObject())
|
||||
roles := httpmw.UserAuthorization(r)
|
||||
err := h.Authorizer.ByRoleName(r.Context(), roles.ID.String(), roles.Roles, roles.Scope.ToRBAC(), roles.Groups, action, object.RBACObject())
|
||||
if err != nil {
|
||||
// Log the errors for debugging
|
||||
internalError := new(rbac.UnauthorizedError)
|
||||
@@ -70,6 +79,7 @@ func (h *HTTPAuthorizer) Authorize(r *http.Request, action rbac.Action, object r
|
||||
slog.F("roles", roles.Roles),
|
||||
slog.F("user_id", roles.ID),
|
||||
slog.F("username", roles.Username),
|
||||
slog.F("scope", roles.Scope),
|
||||
slog.F("route", r.URL.Path),
|
||||
slog.F("action", action),
|
||||
slog.F("object", object),
|
||||
@@ -79,3 +89,132 @@ func (h *HTTPAuthorizer) Authorize(r *http.Request, action rbac.Action, object r
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// AuthorizeSQLFilter returns an authorization filter that can used in a
|
||||
// SQL 'WHERE' clause. If the filter is used, the resulting rows returned
|
||||
// from postgres are already authorized, and the caller does not need to
|
||||
// call 'Authorize()' on the returned objects.
|
||||
// Note the authorization is only for the given action and object type.
|
||||
func (h *HTTPAuthorizer) AuthorizeSQLFilter(r *http.Request, action rbac.Action, objectType string) (rbac.AuthorizeFilter, error) {
|
||||
roles := httpmw.UserAuthorization(r)
|
||||
prepared, err := h.Authorizer.PrepareByRoleName(r.Context(), roles.ID.String(), roles.Roles, roles.Scope.ToRBAC(), roles.Groups, action, objectType)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("prepare filter: %w", err)
|
||||
}
|
||||
|
||||
filter, err := prepared.Compile()
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("compile filter: %w", err)
|
||||
}
|
||||
|
||||
return filter, nil
|
||||
}
|
||||
|
||||
// checkAuthorization returns if the current API key can use the given
|
||||
// permissions, factoring in the current user's roles and the API key scopes.
|
||||
func (api *API) checkAuthorization(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
auth := httpmw.UserAuthorization(r)
|
||||
|
||||
var params codersdk.AuthorizationRequest
|
||||
if !httpapi.Read(ctx, rw, r, ¶ms) {
|
||||
return
|
||||
}
|
||||
|
||||
api.Logger.Debug(ctx, "check-auth",
|
||||
slog.F("my_id", httpmw.APIKey(r).UserID),
|
||||
slog.F("got_id", auth.ID),
|
||||
slog.F("name", auth.Username),
|
||||
slog.F("roles", auth.Roles), slog.F("scope", auth.Scope),
|
||||
)
|
||||
|
||||
response := make(codersdk.AuthorizationResponse)
|
||||
// Prevent using too many resources by ID. This prevents database abuse
|
||||
// from this endpoint. This also prevents misuse of this endpoint, as
|
||||
// resource_id should be used for single objects, not for a list of them.
|
||||
var (
|
||||
idFetch int
|
||||
maxFetch = 10
|
||||
)
|
||||
for _, v := range params.Checks {
|
||||
if v.Object.ResourceID != "" {
|
||||
idFetch++
|
||||
}
|
||||
}
|
||||
if idFetch > maxFetch {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: fmt.Sprintf(
|
||||
"Endpoint only supports using \"resource_id\" field %d times, found %d usages. Remove %d objects with this field set.",
|
||||
maxFetch, idFetch, idFetch-maxFetch,
|
||||
),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
for k, v := range params.Checks {
|
||||
if v.Object.ResourceType == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: fmt.Sprintf("Object's \"resource_type\" field must be defined for key %q.", k),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
obj := rbac.Object{
|
||||
Owner: v.Object.OwnerID,
|
||||
OrgID: v.Object.OrganizationID,
|
||||
Type: v.Object.ResourceType,
|
||||
}
|
||||
if obj.Owner == "me" {
|
||||
obj.Owner = auth.ID.String()
|
||||
}
|
||||
|
||||
// If a resource ID is specified, fetch that specific resource.
|
||||
if v.Object.ResourceID != "" {
|
||||
id, err := uuid.Parse(v.Object.ResourceID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: fmt.Sprintf("Object %q id is not a valid uuid.", v.Object.ResourceID),
|
||||
Validations: []codersdk.ValidationError{{Field: "resource_id", Detail: err.Error()}},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var dbObj rbac.Objecter
|
||||
var dbErr error
|
||||
// Only support referencing some resources by ID.
|
||||
switch v.Object.ResourceType {
|
||||
case rbac.ResourceWorkspaceExecution.Type:
|
||||
wrkSpace, err := api.Database.GetWorkspaceByID(ctx, id)
|
||||
if err == nil {
|
||||
dbObj = wrkSpace.ExecutionRBAC()
|
||||
}
|
||||
dbErr = err
|
||||
case rbac.ResourceWorkspace.Type:
|
||||
dbObj, dbErr = api.Database.GetWorkspaceByID(ctx, id)
|
||||
case rbac.ResourceTemplate.Type:
|
||||
dbObj, dbErr = api.Database.GetTemplateByID(ctx, id)
|
||||
case rbac.ResourceUser.Type:
|
||||
dbObj, dbErr = api.Database.GetUserByID(ctx, id)
|
||||
case rbac.ResourceGroup.Type:
|
||||
dbObj, dbErr = api.Database.GetGroupByID(ctx, id)
|
||||
default:
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: fmt.Sprintf("Object type %q does not support \"resource_id\" field.", v.Object.ResourceType),
|
||||
Validations: []codersdk.ValidationError{{Field: "resource_type", Detail: err.Error()}},
|
||||
})
|
||||
return
|
||||
}
|
||||
if dbErr != nil {
|
||||
// 404 or unauthorized is false
|
||||
response[k] = false
|
||||
continue
|
||||
}
|
||||
obj = dbObj.RBACObject()
|
||||
}
|
||||
|
||||
err := api.Authorizer.ByRoleName(r.Context(), auth.ID.String(), auth.Roles, auth.Scope.ToRBAC(), auth.Groups, rbac.Action(v.Action), obj)
|
||||
response[k] = err == nil
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, response)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,141 @@
|
||||
package coderd_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
func TestCheckPermissions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
adminClient := coderdtest.New(t, &coderdtest.Options{
|
||||
IncludeProvisionerDaemon: true,
|
||||
})
|
||||
// Create adminClient, member, and org adminClient
|
||||
adminUser := coderdtest.CreateFirstUser(t, adminClient)
|
||||
memberClient := coderdtest.CreateAnotherUser(t, adminClient, adminUser.OrganizationID)
|
||||
memberUser, err := memberClient.User(ctx, codersdk.Me)
|
||||
require.NoError(t, err)
|
||||
orgAdminClient := coderdtest.CreateAnotherUser(t, adminClient, adminUser.OrganizationID, rbac.RoleOrgAdmin(adminUser.OrganizationID))
|
||||
orgAdminUser, err := orgAdminClient.User(ctx, codersdk.Me)
|
||||
require.NoError(t, err)
|
||||
|
||||
version := coderdtest.CreateTemplateVersion(t, adminClient, adminUser.OrganizationID, nil)
|
||||
coderdtest.AwaitTemplateVersionJob(t, adminClient, version.ID)
|
||||
template := coderdtest.CreateTemplate(t, adminClient, adminUser.OrganizationID, version.ID)
|
||||
|
||||
// With admin, member, and org admin
|
||||
const (
|
||||
readAllUsers = "read-all-users"
|
||||
readOrgWorkspaces = "read-org-workspaces"
|
||||
readMyself = "read-myself"
|
||||
readOwnWorkspaces = "read-own-workspaces"
|
||||
updateSpecificTemplate = "update-specific-template"
|
||||
)
|
||||
params := map[string]codersdk.AuthorizationCheck{
|
||||
readAllUsers: {
|
||||
Object: codersdk.AuthorizationObject{
|
||||
ResourceType: "users",
|
||||
},
|
||||
Action: "read",
|
||||
},
|
||||
readMyself: {
|
||||
Object: codersdk.AuthorizationObject{
|
||||
ResourceType: "users",
|
||||
OwnerID: "me",
|
||||
},
|
||||
Action: "read",
|
||||
},
|
||||
readOwnWorkspaces: {
|
||||
Object: codersdk.AuthorizationObject{
|
||||
ResourceType: "workspaces",
|
||||
OwnerID: "me",
|
||||
},
|
||||
Action: "read",
|
||||
},
|
||||
readOrgWorkspaces: {
|
||||
Object: codersdk.AuthorizationObject{
|
||||
ResourceType: "workspaces",
|
||||
OrganizationID: adminUser.OrganizationID.String(),
|
||||
},
|
||||
Action: "read",
|
||||
},
|
||||
updateSpecificTemplate: {
|
||||
Object: codersdk.AuthorizationObject{
|
||||
ResourceType: rbac.ResourceTemplate.Type,
|
||||
ResourceID: template.ID.String(),
|
||||
},
|
||||
Action: "update",
|
||||
},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
Name string
|
||||
Client *codersdk.Client
|
||||
UserID uuid.UUID
|
||||
Check codersdk.AuthorizationResponse
|
||||
}{
|
||||
{
|
||||
Name: "Admin",
|
||||
Client: adminClient,
|
||||
UserID: adminUser.UserID,
|
||||
Check: map[string]bool{
|
||||
readAllUsers: true,
|
||||
readMyself: true,
|
||||
readOwnWorkspaces: true,
|
||||
readOrgWorkspaces: true,
|
||||
updateSpecificTemplate: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "OrgAdmin",
|
||||
Client: orgAdminClient,
|
||||
UserID: orgAdminUser.ID,
|
||||
Check: map[string]bool{
|
||||
readAllUsers: false,
|
||||
readMyself: true,
|
||||
readOwnWorkspaces: true,
|
||||
readOrgWorkspaces: true,
|
||||
updateSpecificTemplate: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Member",
|
||||
Client: memberClient,
|
||||
UserID: memberUser.ID,
|
||||
Check: map[string]bool{
|
||||
readAllUsers: false,
|
||||
readMyself: true,
|
||||
readOwnWorkspaces: true,
|
||||
readOrgWorkspaces: false,
|
||||
updateSpecificTemplate: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range testCases {
|
||||
c := c
|
||||
|
||||
t.Run("CheckAuthorization/"+c.Name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
resp, err := c.Client.CheckAuthorization(ctx, codersdk.AuthorizationRequest{Checks: params})
|
||||
require.NoError(t, err, "check perms")
|
||||
require.Equal(t, c.Check, resp)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -276,7 +276,7 @@ func build(ctx context.Context, store database.Store, workspace database.Workspa
|
||||
Provisioner: template.Provisioner,
|
||||
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
||||
StorageMethod: priorJob.StorageMethod,
|
||||
StorageSource: priorJob.StorageSource,
|
||||
FileID: priorJob.FileID,
|
||||
Input: input,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
+172
-83
@@ -1,21 +1,23 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/andybalholm/brotli"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/pion/webrtc/v3"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/xerrors"
|
||||
"google.golang.org/api/idtoken"
|
||||
"tailscale.com/derp"
|
||||
@@ -25,6 +27,7 @@ import (
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/buildinfo"
|
||||
"github.com/coder/coder/coderd/audit"
|
||||
"github.com/coder/coder/coderd/awsidentity"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/gitsshkey"
|
||||
@@ -34,7 +37,7 @@ import (
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
"github.com/coder/coder/coderd/telemetry"
|
||||
"github.com/coder/coder/coderd/tracing"
|
||||
"github.com/coder/coder/coderd/turnconn"
|
||||
"github.com/coder/coder/coderd/workspacequota"
|
||||
"github.com/coder/coder/coderd/wsconncache"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/site"
|
||||
@@ -44,13 +47,23 @@ import (
|
||||
// Options are requires parameters for Coder to start.
|
||||
type Options struct {
|
||||
AccessURL *url.URL
|
||||
Logger slog.Logger
|
||||
Database database.Store
|
||||
Pubsub database.Pubsub
|
||||
// AppHostname should be the wildcard hostname to use for workspace
|
||||
// applications INCLUDING the asterisk, (optional) suffix and leading dot.
|
||||
// E.g. "*.apps.coder.com" or "*-apps.coder.com".
|
||||
AppHostname string
|
||||
// AppHostnameRegex contains the regex version of options.AppHostname as
|
||||
// generated by httpapi.CompileHostnamePattern(). It MUST be set if
|
||||
// options.AppHostname is set.
|
||||
AppHostnameRegex *regexp.Regexp
|
||||
Logger slog.Logger
|
||||
Database database.Store
|
||||
Pubsub database.Pubsub
|
||||
|
||||
// CacheDir is used for caching files served by the API.
|
||||
CacheDir string
|
||||
|
||||
Auditor audit.Auditor
|
||||
WorkspaceQuotaEnforcer workspacequota.Enforcer
|
||||
AgentConnectionUpdateFrequency time.Duration
|
||||
AgentInactiveDisconnectTimeout time.Duration
|
||||
// APIRateLimit is the minutely throughput rate limit per user or ip.
|
||||
@@ -64,26 +77,32 @@ type Options struct {
|
||||
GithubOAuth2Config *GithubOAuth2Config
|
||||
OIDCConfig *OIDCConfig
|
||||
PrometheusRegistry *prometheus.Registry
|
||||
ICEServers []webrtc.ICEServer
|
||||
SecureAuthCookie bool
|
||||
SSHKeygenAlgorithm gitsshkey.Algorithm
|
||||
Telemetry telemetry.Reporter
|
||||
TURNServer *turnconn.Server
|
||||
TracerProvider *sdktrace.TracerProvider
|
||||
TracerProvider trace.TracerProvider
|
||||
AutoImportTemplates []AutoImportTemplate
|
||||
LicenseHandler http.Handler
|
||||
FeaturesService FeaturesService
|
||||
|
||||
TailscaleEnable bool
|
||||
TailnetCoordinator *tailnet.Coordinator
|
||||
// TLSCertificates is used to mesh DERP servers securely.
|
||||
TLSCertificates []tls.Certificate
|
||||
TailnetCoordinator tailnet.Coordinator
|
||||
DERPServer *derp.Server
|
||||
DERPMap *tailcfg.DERPMap
|
||||
|
||||
MetricsCacheRefreshInterval time.Duration
|
||||
AgentStatsRefreshInterval time.Duration
|
||||
Experimental bool
|
||||
DeploymentFlags *codersdk.DeploymentFlags
|
||||
}
|
||||
|
||||
// New constructs a Coder API handler.
|
||||
func New(options *Options) *API {
|
||||
if options == nil {
|
||||
options = &Options{}
|
||||
}
|
||||
if options.AppHostname != "" && options.AppHostnameRegex == nil || options.AppHostname == "" && options.AppHostnameRegex != nil {
|
||||
panic("coderd: both AppHostname and AppHostnameRegex must be set or unset")
|
||||
}
|
||||
if options.AgentConnectionUpdateFrequency == 0 {
|
||||
options.AgentConnectionUpdateFrequency = 3 * time.Second
|
||||
}
|
||||
@@ -91,17 +110,23 @@ func New(options *Options) *API {
|
||||
// Multiply the update by two to allow for some lag-time.
|
||||
options.AgentInactiveDisconnectTimeout = options.AgentConnectionUpdateFrequency * 2
|
||||
}
|
||||
if options.AgentStatsRefreshInterval == 0 {
|
||||
options.AgentStatsRefreshInterval = 10 * time.Minute
|
||||
}
|
||||
if options.MetricsCacheRefreshInterval == 0 {
|
||||
options.MetricsCacheRefreshInterval = time.Hour
|
||||
}
|
||||
if options.APIRateLimit == 0 {
|
||||
options.APIRateLimit = 512
|
||||
}
|
||||
if options.AgentStatsRefreshInterval == 0 {
|
||||
options.AgentStatsRefreshInterval = 10 * time.Minute
|
||||
}
|
||||
if options.MetricsCacheRefreshInterval == 0 {
|
||||
options.MetricsCacheRefreshInterval = time.Hour
|
||||
}
|
||||
if options.Authorizer == nil {
|
||||
var err error
|
||||
options.Authorizer, err = rbac.NewAuthorizer()
|
||||
if err != nil {
|
||||
// This should never happen, as the unit tests would fail if the
|
||||
// default built in authorizer failed.
|
||||
panic(xerrors.Errorf("rego authorize panic: %w", err))
|
||||
}
|
||||
options.Authorizer = rbac.NewAuthorizer()
|
||||
}
|
||||
if options.PrometheusRegistry == nil {
|
||||
options.PrometheusRegistry = prometheus.NewRegistry()
|
||||
@@ -109,11 +134,14 @@ func New(options *Options) *API {
|
||||
if options.TailnetCoordinator == nil {
|
||||
options.TailnetCoordinator = tailnet.NewCoordinator()
|
||||
}
|
||||
if options.LicenseHandler == nil {
|
||||
options.LicenseHandler = licenses()
|
||||
if options.DERPServer == nil {
|
||||
options.DERPServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger.Named("derp")))
|
||||
}
|
||||
if options.FeaturesService == nil {
|
||||
options.FeaturesService = featuresService{}
|
||||
if options.Auditor == nil {
|
||||
options.Auditor = audit.NewNop()
|
||||
}
|
||||
if options.WorkspaceQuotaEnforcer == nil {
|
||||
options.WorkspaceQuotaEnforcer = workspacequota.NewNop()
|
||||
}
|
||||
|
||||
siteCacheDir := options.CacheDir
|
||||
@@ -134,46 +162,86 @@ func New(options *Options) *API {
|
||||
r := chi.NewRouter()
|
||||
api := &API{
|
||||
Options: options,
|
||||
Handler: r,
|
||||
RootHandler: r,
|
||||
siteHandler: site.Handler(site.FS(), binFS),
|
||||
httpAuth: &HTTPAuthorizer{
|
||||
HTTPAuth: &HTTPAuthorizer{
|
||||
Authorizer: options.Authorizer,
|
||||
Logger: options.Logger,
|
||||
},
|
||||
metricsCache: metricsCache,
|
||||
metricsCache: metricsCache,
|
||||
Auditor: atomic.Pointer[audit.Auditor]{},
|
||||
WorkspaceQuotaEnforcer: atomic.Pointer[workspacequota.Enforcer]{},
|
||||
}
|
||||
if options.TailscaleEnable {
|
||||
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0)
|
||||
} else {
|
||||
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgent, 0)
|
||||
}
|
||||
api.derpServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger))
|
||||
api.Auditor.Store(&options.Auditor)
|
||||
api.WorkspaceQuotaEnforcer.Store(&options.WorkspaceQuotaEnforcer)
|
||||
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0)
|
||||
api.TailnetCoordinator.Store(&options.TailnetCoordinator)
|
||||
oauthConfigs := &httpmw.OAuth2Configs{
|
||||
Github: options.GithubOAuth2Config,
|
||||
OIDC: options.OIDCConfig,
|
||||
}
|
||||
apiKeyMiddleware := httpmw.ExtractAPIKey(options.Database, oauthConfigs, false)
|
||||
|
||||
apiKeyMiddleware := httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: options.Database,
|
||||
OAuth2Configs: oauthConfigs,
|
||||
RedirectToLogin: false,
|
||||
Optional: false,
|
||||
})
|
||||
// Same as above but it redirects to the login page.
|
||||
apiKeyMiddlewareRedirect := httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: options.Database,
|
||||
OAuth2Configs: oauthConfigs,
|
||||
RedirectToLogin: true,
|
||||
Optional: false,
|
||||
})
|
||||
|
||||
r.Use(
|
||||
httpmw.AttachRequestID,
|
||||
httpmw.Recover(api.Logger),
|
||||
httpmw.Logger(api.Logger),
|
||||
httpmw.Prometheus(options.PrometheusRegistry),
|
||||
// handleSubdomainApplications checks if the first subdomain is a valid
|
||||
// app URL. If it is, it will serve that application.
|
||||
api.handleSubdomainApplications(
|
||||
// Middleware to impose on the served application.
|
||||
httpmw.RateLimitPerMinute(options.APIRateLimit),
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: options.Database,
|
||||
OAuth2Configs: oauthConfigs,
|
||||
// The code handles the the case where the user is not
|
||||
// authenticated automatically.
|
||||
RedirectToLogin: false,
|
||||
Optional: true,
|
||||
}),
|
||||
httpmw.ExtractUserParam(api.Database, false),
|
||||
httpmw.ExtractWorkspaceAndAgentParam(api.Database),
|
||||
),
|
||||
// Build-Version is helpful for debugging.
|
||||
func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Add("Build-Version", buildinfo.Version())
|
||||
w.Header().Add("X-Coder-Build-Version", buildinfo.Version())
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
},
|
||||
httpmw.CSRF(options.SecureAuthCookie),
|
||||
)
|
||||
|
||||
apps := func(r chi.Router) {
|
||||
r.Use(
|
||||
tracing.Middleware(api.TracerProvider),
|
||||
httpmw.RateLimitPerMinute(options.APIRateLimit),
|
||||
tracing.HTTPMW(api.TracerProvider, "coderd.http"),
|
||||
httpmw.ExtractAPIKey(options.Database, oauthConfigs, true),
|
||||
httpmw.ExtractUserParam(api.Database),
|
||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||
DB: options.Database,
|
||||
OAuth2Configs: oauthConfigs,
|
||||
// Optional is true to allow for public apps. If an
|
||||
// authorization check fails and the user is not authenticated,
|
||||
// they will be redirected to the login page by the app handler.
|
||||
RedirectToLogin: false,
|
||||
Optional: true,
|
||||
}),
|
||||
// Redirect to the login page if the user tries to open an app with
|
||||
// "me" as the username and they are not logged in.
|
||||
httpmw.ExtractUserParam(api.Database, true),
|
||||
// Extracts the <workspace.agent> from the url
|
||||
httpmw.ExtractWorkspaceAndAgentParam(api.Database),
|
||||
)
|
||||
@@ -185,7 +253,7 @@ func New(options *Options) *API {
|
||||
r.Route("/%40{user}/{workspace_and_agent}/apps/{workspaceapp}", apps)
|
||||
r.Route("/@{user}/{workspace_and_agent}/apps/{workspaceapp}", apps)
|
||||
r.Route("/derp", func(r chi.Router) {
|
||||
r.Get("/", derphttp.Handler(api.derpServer).ServeHTTP)
|
||||
r.Get("/", derphttp.Handler(api.DERPServer).ServeHTTP)
|
||||
// This is used when UDP is blocked, and latency must be checked via HTTP(s).
|
||||
r.Get("/latency-check", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
@@ -193,18 +261,16 @@ func New(options *Options) *API {
|
||||
})
|
||||
|
||||
r.Route("/api/v2", func(r chi.Router) {
|
||||
r.NotFound(func(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(rw, http.StatusNotFound, codersdk.Response{
|
||||
Message: "Route not found.",
|
||||
})
|
||||
})
|
||||
api.APIHandler = r
|
||||
|
||||
r.NotFound(func(rw http.ResponseWriter, r *http.Request) { httpapi.RouteNotFound(rw) })
|
||||
r.Use(
|
||||
tracing.Middleware(api.TracerProvider),
|
||||
// Specific routes can specify smaller limits.
|
||||
httpmw.RateLimitPerMinute(options.APIRateLimit),
|
||||
tracing.HTTPMW(api.TracerProvider, "coderd.http"),
|
||||
)
|
||||
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(w, http.StatusOK, codersdk.Response{
|
||||
httpapi.Write(r.Context(), w, http.StatusOK, codersdk.Response{
|
||||
//nolint:gocritic
|
||||
Message: "👋",
|
||||
})
|
||||
@@ -214,12 +280,16 @@ func New(options *Options) *API {
|
||||
|
||||
r.Route("/buildinfo", func(r chi.Router) {
|
||||
r.Get("/", func(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(rw, http.StatusOK, codersdk.BuildInfoResponse{
|
||||
httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.BuildInfoResponse{
|
||||
ExternalURL: buildinfo.ExternalURL(),
|
||||
Version: buildinfo.Version(),
|
||||
})
|
||||
})
|
||||
})
|
||||
r.Route("/flags", func(r chi.Router) {
|
||||
r.Use(apiKeyMiddleware)
|
||||
r.Get("/deployment", api.deploymentFlags)
|
||||
})
|
||||
r.Route("/audit", func(r chi.Router) {
|
||||
r.Use(
|
||||
apiKeyMiddleware,
|
||||
@@ -236,9 +306,10 @@ func New(options *Options) *API {
|
||||
// file content is expensive so it should be small.
|
||||
httpmw.RateLimitPerMinute(12),
|
||||
)
|
||||
r.Get("/{hash}", api.fileByHash)
|
||||
r.Get("/{fileID}", api.fileByID)
|
||||
r.Post("/", api.postFile)
|
||||
})
|
||||
|
||||
r.Route("/provisionerdaemons", func(r chi.Router) {
|
||||
r.Use(
|
||||
apiKeyMiddleware,
|
||||
@@ -261,15 +332,15 @@ func New(options *Options) *API {
|
||||
r.Get("/", api.templatesByOrganization)
|
||||
r.Get("/{templatename}", api.templateByOrganizationAndName)
|
||||
})
|
||||
r.Post("/workspaces", api.postWorkspacesByOrganization)
|
||||
r.Route("/members", func(r chi.Router) {
|
||||
r.Get("/roles", api.assignableOrgRoles)
|
||||
r.Route("/{user}", func(r chi.Router) {
|
||||
r.Use(
|
||||
httpmw.ExtractUserParam(options.Database),
|
||||
httpmw.ExtractUserParam(options.Database, false),
|
||||
httpmw.ExtractOrganizationMemberParam(options.Database),
|
||||
)
|
||||
r.Put("/roles", api.putMemberRoles)
|
||||
r.Post("/workspaces", api.postWorkspacesByOrganization)
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -344,7 +415,8 @@ func New(options *Options) *API {
|
||||
r.Get("/", api.assignableSiteRoles)
|
||||
})
|
||||
r.Route("/{user}", func(r chi.Router) {
|
||||
r.Use(httpmw.ExtractUserParam(options.Database))
|
||||
r.Use(httpmw.ExtractUserParam(options.Database, false))
|
||||
r.Delete("/", api.deleteUser)
|
||||
r.Get("/", api.userByName)
|
||||
r.Put("/profile", api.putUserProfile)
|
||||
r.Route("/status", func(r chi.Router) {
|
||||
@@ -358,11 +430,16 @@ func New(options *Options) *API {
|
||||
r.Put("/roles", api.putUserRoles)
|
||||
r.Get("/roles", api.userRoles)
|
||||
|
||||
r.Post("/authorization", api.checkPermissions)
|
||||
|
||||
r.Route("/keys", func(r chi.Router) {
|
||||
r.Post("/", api.postAPIKey)
|
||||
r.Get("/{keyid}", api.apiKey)
|
||||
r.Route("/tokens", func(r chi.Router) {
|
||||
r.Post("/", api.postToken)
|
||||
r.Get("/", api.tokens)
|
||||
})
|
||||
r.Route("/{keyid}", func(r chi.Router) {
|
||||
r.Get("/", api.apiKey)
|
||||
r.Delete("/", api.deleteAPIKey)
|
||||
})
|
||||
})
|
||||
|
||||
r.Route("/organizations", func(r chi.Router) {
|
||||
@@ -384,16 +461,12 @@ func New(options *Options) *API {
|
||||
r.Post("/google-instance-identity", api.postWorkspaceAuthGoogleInstanceIdentity)
|
||||
r.Route("/me", func(r chi.Router) {
|
||||
r.Use(httpmw.ExtractWorkspaceAgent(options.Database))
|
||||
r.Get("/apps", api.workspaceAgentApps)
|
||||
r.Get("/metadata", api.workspaceAgentMetadata)
|
||||
r.Post("/version", api.postWorkspaceAgentVersion)
|
||||
r.Get("/listen", api.workspaceAgentListen)
|
||||
|
||||
r.Post("/app-health", api.postWorkspaceAppHealth)
|
||||
r.Get("/gitsshkey", api.agentGitSSHKey)
|
||||
r.Get("/turn", api.workspaceAgentTurn)
|
||||
r.Get("/iceservers", api.workspaceAgentICEServers)
|
||||
|
||||
r.Get("/coordinate", api.workspaceAgentCoordinate)
|
||||
|
||||
r.Get("/report-stats", api.workspaceAgentReportStats)
|
||||
})
|
||||
r.Route("/{workspaceagent}", func(r chi.Router) {
|
||||
@@ -403,23 +476,20 @@ func New(options *Options) *API {
|
||||
httpmw.ExtractWorkspaceParam(options.Database),
|
||||
)
|
||||
r.Get("/", api.workspaceAgent)
|
||||
r.Get("/dial", api.workspaceAgentDial)
|
||||
r.Get("/turn", api.userWorkspaceAgentTurn)
|
||||
r.Get("/pty", api.workspaceAgentPTY)
|
||||
r.Get("/iceservers", api.workspaceAgentICEServers)
|
||||
|
||||
r.Get("/listening-ports", api.workspaceAgentListeningPorts)
|
||||
r.Get("/connection", api.workspaceAgentConnection)
|
||||
r.Get("/coordinate", api.workspaceAgentClientCoordinate)
|
||||
// TODO: This can be removed in October. It allows for a friendly
|
||||
// error message when transitioning from WebRTC to Tailscale. See:
|
||||
// https://github.com/coder/coder/issues/4126
|
||||
r.Get("/dial", func(w http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(r.Context(), w, http.StatusGone, codersdk.Response{
|
||||
Message: "Your Coder CLI is out of date, and requires v0.8.15+ to connect!",
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
r.Route("/workspaceresources/{workspaceresource}", func(r chi.Router) {
|
||||
r.Use(
|
||||
apiKeyMiddleware,
|
||||
httpmw.ExtractWorkspaceResourceParam(options.Database),
|
||||
httpmw.ExtractWorkspaceParam(options.Database),
|
||||
)
|
||||
r.Get("/", api.workspaceResource)
|
||||
})
|
||||
r.Route("/workspaces", func(r chi.Router) {
|
||||
r.Use(
|
||||
apiKeyMiddleware,
|
||||
@@ -457,13 +527,24 @@ func New(options *Options) *API {
|
||||
r.Get("/resources", api.workspaceBuildResources)
|
||||
r.Get("/state", api.workspaceBuildState)
|
||||
})
|
||||
r.Route("/entitlements", func(r chi.Router) {
|
||||
r.Route("/authcheck", func(r chi.Router) {
|
||||
r.Use(apiKeyMiddleware)
|
||||
r.Get("/", api.FeaturesService.EntitlementsAPI)
|
||||
r.Post("/", api.checkAuthorization)
|
||||
})
|
||||
r.Route("/licenses", func(r chi.Router) {
|
||||
r.Use(apiKeyMiddleware)
|
||||
r.Mount("/", options.LicenseHandler)
|
||||
r.Route("/applications", func(r chi.Router) {
|
||||
r.Route("/host", func(r chi.Router) {
|
||||
// Don't leak the hostname to unauthenticated users.
|
||||
r.Use(apiKeyMiddleware)
|
||||
r.Get("/", api.appHost)
|
||||
})
|
||||
r.Route("/auth-redirect", func(r chi.Router) {
|
||||
// We want to redirect to login if they are not authenticated.
|
||||
r.Use(apiKeyMiddlewareRedirect)
|
||||
|
||||
// This is a GET request as it's redirected to by the subdomain app
|
||||
// handler and the login page.
|
||||
r.Get("/", api.workspaceApplicationAuth)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -473,17 +554,22 @@ func New(options *Options) *API {
|
||||
|
||||
type API struct {
|
||||
*Options
|
||||
Auditor atomic.Pointer[audit.Auditor]
|
||||
WorkspaceClientCoordinateOverride atomic.Pointer[func(rw http.ResponseWriter) bool]
|
||||
WorkspaceQuotaEnforcer atomic.Pointer[workspacequota.Enforcer]
|
||||
TailnetCoordinator atomic.Pointer[tailnet.Coordinator]
|
||||
HTTPAuth *HTTPAuthorizer
|
||||
|
||||
derpServer *derp.Server
|
||||
// APIHandler serves "/api/v2"
|
||||
APIHandler chi.Router
|
||||
// RootHandler serves "/"
|
||||
RootHandler chi.Router
|
||||
|
||||
Handler chi.Router
|
||||
metricsCache *metricscache.Cache
|
||||
siteHandler http.Handler
|
||||
websocketWaitMutex sync.Mutex
|
||||
websocketWaitGroup sync.WaitGroup
|
||||
workspaceAgentCache *wsconncache.Cache
|
||||
httpAuth *HTTPAuthorizer
|
||||
|
||||
metricsCache *metricscache.Cache
|
||||
}
|
||||
|
||||
// Close waits for all WebSocket connections to drain before returning.
|
||||
@@ -493,7 +579,10 @@ func (api *API) Close() error {
|
||||
api.websocketWaitMutex.Unlock()
|
||||
|
||||
api.metricsCache.Close()
|
||||
|
||||
coordinator := api.TailnetCoordinator.Load()
|
||||
if coordinator != nil {
|
||||
_ = (*coordinator).Close()
|
||||
}
|
||||
return api.workspaceAgentCache.Close()
|
||||
}
|
||||
|
||||
|
||||
@@ -38,16 +38,6 @@ func TestBuildInfo(t *testing.T) {
|
||||
require.Equal(t, buildinfo.Version(), buildInfo.Version, "version")
|
||||
}
|
||||
|
||||
// TestAuthorizeAllEndpoints will check `authorize` is called on every endpoint registered.
|
||||
func TestAuthorizeAllEndpoints(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
a := coderdtest.NewAuthTester(ctx, t, nil)
|
||||
skipRoutes, assertRoute := coderdtest.AGPLRoutes(a)
|
||||
a.Test(ctx, assertRoute, skipRoutes)
|
||||
}
|
||||
|
||||
func TestDERP(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
|
||||
@@ -22,147 +22,12 @@ import (
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
type RouteCheck struct {
|
||||
NoAuthorize bool
|
||||
AssertAction rbac.Action
|
||||
AssertObject rbac.Object
|
||||
StatusCode int
|
||||
}
|
||||
|
||||
type AuthTester struct {
|
||||
t *testing.T
|
||||
api *coderd.API
|
||||
authorizer *recordingAuthorizer
|
||||
|
||||
Client *codersdk.Client
|
||||
Workspace codersdk.Workspace
|
||||
Organization codersdk.Organization
|
||||
Admin codersdk.CreateFirstUserResponse
|
||||
Template codersdk.Template
|
||||
Version codersdk.TemplateVersion
|
||||
WorkspaceResource codersdk.WorkspaceResource
|
||||
File codersdk.UploadResponse
|
||||
TemplateVersionDryRun codersdk.ProvisionerJob
|
||||
TemplateParam codersdk.Parameter
|
||||
URLParams map[string]string
|
||||
}
|
||||
|
||||
func NewAuthTester(ctx context.Context, t *testing.T, options *Options) *AuthTester {
|
||||
authorizer := &recordingAuthorizer{}
|
||||
if options == nil {
|
||||
options = &Options{}
|
||||
}
|
||||
if options.Authorizer != nil {
|
||||
t.Error("NewAuthTester cannot be called with custom Authorizer")
|
||||
}
|
||||
options.Authorizer = authorizer
|
||||
options.IncludeProvisionerDaemon = true
|
||||
|
||||
client, _, api := newWithAPI(t, options)
|
||||
admin := CreateFirstUser(t, client)
|
||||
// The provisioner will call to coderd and register itself. This is async,
|
||||
// so we wait for it to occur.
|
||||
require.Eventually(t, func() bool {
|
||||
provisionerds, err := client.ProvisionerDaemons(ctx)
|
||||
return assert.NoError(t, err) && len(provisionerds) > 0
|
||||
}, testutil.WaitLong, testutil.IntervalSlow)
|
||||
|
||||
provisionerds, err := client.ProvisionerDaemons(ctx)
|
||||
require.NoError(t, err, "fetch provisioners")
|
||||
require.Len(t, provisionerds, 1)
|
||||
|
||||
organization, err := client.Organization(ctx, admin.OrganizationID)
|
||||
require.NoError(t, err, "fetch org")
|
||||
|
||||
// Setup some data in the database.
|
||||
version := CreateTemplateVersion(t, client, admin.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
Provision: []*proto.Provision_Response{{
|
||||
Type: &proto.Provision_Response_Complete{
|
||||
Complete: &proto.Provision_Complete{
|
||||
// Return a workspace resource
|
||||
Resources: []*proto.Resource{{
|
||||
Name: "some",
|
||||
Type: "example",
|
||||
Agents: []*proto.Agent{{
|
||||
Name: "agent",
|
||||
Id: "something",
|
||||
Auth: &proto.Agent_Token{},
|
||||
Apps: []*proto.App{{
|
||||
Name: "testapp",
|
||||
Url: "http://localhost:3000",
|
||||
}},
|
||||
}},
|
||||
}},
|
||||
},
|
||||
},
|
||||
}},
|
||||
})
|
||||
AwaitTemplateVersionJob(t, client, version.ID)
|
||||
template := CreateTemplate(t, client, admin.OrganizationID, version.ID)
|
||||
workspace := CreateWorkspace(t, client, admin.OrganizationID, template.ID)
|
||||
AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
|
||||
file, err := client.Upload(ctx, codersdk.ContentTypeTar, make([]byte, 1024))
|
||||
require.NoError(t, err, "upload file")
|
||||
workspaceResources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID)
|
||||
require.NoError(t, err, "workspace resources")
|
||||
templateVersionDryRun, err := client.CreateTemplateVersionDryRun(ctx, version.ID, codersdk.CreateTemplateVersionDryRunRequest{
|
||||
ParameterValues: []codersdk.CreateParameterRequest{},
|
||||
})
|
||||
require.NoError(t, err, "template version dry-run")
|
||||
|
||||
templateParam, err := client.CreateParameter(ctx, codersdk.ParameterTemplate, template.ID, codersdk.CreateParameterRequest{
|
||||
Name: "test-param",
|
||||
SourceValue: "hello world",
|
||||
SourceScheme: codersdk.ParameterSourceSchemeData,
|
||||
DestinationScheme: codersdk.ParameterDestinationSchemeProvisionerVariable,
|
||||
})
|
||||
require.NoError(t, err, "create template param")
|
||||
|
||||
urlParameters := map[string]string{
|
||||
"{organization}": admin.OrganizationID.String(),
|
||||
"{user}": admin.UserID.String(),
|
||||
"{organizationname}": organization.Name,
|
||||
"{workspace}": workspace.ID.String(),
|
||||
"{workspacebuild}": workspace.LatestBuild.ID.String(),
|
||||
"{workspacename}": workspace.Name,
|
||||
"{workspaceagent}": workspaceResources[0].Agents[0].ID.String(),
|
||||
"{buildnumber}": strconv.FormatInt(int64(workspace.LatestBuild.BuildNumber), 10),
|
||||
"{template}": template.ID.String(),
|
||||
"{hash}": file.Hash,
|
||||
"{workspaceresource}": workspaceResources[0].ID.String(),
|
||||
"{workspaceapp}": workspaceResources[0].Agents[0].Apps[0].Name,
|
||||
"{templateversion}": version.ID.String(),
|
||||
"{jobID}": templateVersionDryRun.ID.String(),
|
||||
"{templatename}": template.Name,
|
||||
"{workspace_and_agent}": workspace.Name + "." + workspaceResources[0].Agents[0].Name,
|
||||
// Only checking template scoped params here
|
||||
"parameters/{scope}/{id}": fmt.Sprintf("parameters/%s/%s",
|
||||
string(templateParam.Scope), templateParam.ScopeID.String()),
|
||||
}
|
||||
|
||||
return &AuthTester{
|
||||
t: t,
|
||||
api: api,
|
||||
authorizer: authorizer,
|
||||
Client: client,
|
||||
Workspace: workspace,
|
||||
Organization: organization,
|
||||
Admin: admin,
|
||||
Template: template,
|
||||
Version: version,
|
||||
WorkspaceResource: workspaceResources[0],
|
||||
File: file,
|
||||
TemplateVersionDryRun: templateVersionDryRun,
|
||||
TemplateParam: templateParam,
|
||||
URLParams: urlParameters,
|
||||
}
|
||||
}
|
||||
|
||||
func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
|
||||
// Some quick reused objects
|
||||
workspaceRBACObj := rbac.ResourceWorkspace.InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String())
|
||||
workspaceExecObj := rbac.ResourceWorkspaceExecution.InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String())
|
||||
applicationConnectObj := rbac.ResourceWorkspaceApplicationConnect.InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String())
|
||||
|
||||
// skipRoutes allows skipping routes from being checked.
|
||||
skipRoutes := map[string]string{
|
||||
"POST:/api/v2/users/logout": "Logging out deletes the API Key for other routes",
|
||||
@@ -179,25 +44,26 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
|
||||
"POST:/api/v2/users/login": {NoAuthorize: true},
|
||||
"GET:/api/v2/users/authmethods": {NoAuthorize: true},
|
||||
"POST:/api/v2/csp/reports": {NoAuthorize: true},
|
||||
"GET:/api/v2/entitlements": {NoAuthorize: true},
|
||||
"POST:/api/v2/authcheck": {NoAuthorize: true},
|
||||
"GET:/api/v2/applications/host": {NoAuthorize: true},
|
||||
// This is a dummy endpoint for compatibility with older CLI versions.
|
||||
"GET:/api/v2/workspaceagents/{workspaceagent}/dial": {NoAuthorize: true},
|
||||
|
||||
// Has it's own auth
|
||||
"GET:/api/v2/users/oauth2/github/callback": {NoAuthorize: true},
|
||||
"GET:/api/v2/users/oidc/callback": {NoAuthorize: true},
|
||||
|
||||
// All workspaceagents endpoints do not use rbac
|
||||
"POST:/api/v2/workspaceagents/aws-instance-identity": {NoAuthorize: true},
|
||||
"POST:/api/v2/workspaceagents/azure-instance-identity": {NoAuthorize: true},
|
||||
"POST:/api/v2/workspaceagents/google-instance-identity": {NoAuthorize: true},
|
||||
"GET:/api/v2/workspaceagents/me/gitsshkey": {NoAuthorize: true},
|
||||
"GET:/api/v2/workspaceagents/me/iceservers": {NoAuthorize: true},
|
||||
"GET:/api/v2/workspaceagents/me/listen": {NoAuthorize: true},
|
||||
"GET:/api/v2/workspaceagents/me/metadata": {NoAuthorize: true},
|
||||
"GET:/api/v2/workspaceagents/me/turn": {NoAuthorize: true},
|
||||
"GET:/api/v2/workspaceagents/me/coordinate": {NoAuthorize: true},
|
||||
"POST:/api/v2/workspaceagents/me/version": {NoAuthorize: true},
|
||||
"GET:/api/v2/workspaceagents/me/report-stats": {NoAuthorize: true},
|
||||
"GET:/api/v2/workspaceagents/{workspaceagent}/iceservers": {NoAuthorize: true},
|
||||
"POST:/api/v2/workspaceagents/aws-instance-identity": {NoAuthorize: true},
|
||||
"POST:/api/v2/workspaceagents/azure-instance-identity": {NoAuthorize: true},
|
||||
"POST:/api/v2/workspaceagents/google-instance-identity": {NoAuthorize: true},
|
||||
"GET:/api/v2/workspaceagents/me/apps": {NoAuthorize: true},
|
||||
"GET:/api/v2/workspaceagents/me/gitsshkey": {NoAuthorize: true},
|
||||
"GET:/api/v2/workspaceagents/me/metadata": {NoAuthorize: true},
|
||||
"GET:/api/v2/workspaceagents/me/coordinate": {NoAuthorize: true},
|
||||
"POST:/api/v2/workspaceagents/me/version": {NoAuthorize: true},
|
||||
"POST:/api/v2/workspaceagents/me/app-health": {NoAuthorize: true},
|
||||
"GET:/api/v2/workspaceagents/me/report-stats": {NoAuthorize: true},
|
||||
|
||||
// These endpoints have more assertions. This is good, add more endpoints to assert if you can!
|
||||
"GET:/api/v2/organizations/{organization}": {AssertObject: rbac.ResourceOrganization.InOrg(a.Admin.OrganizationID)},
|
||||
@@ -234,10 +100,6 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
|
||||
AssertAction: rbac.ActionUpdate,
|
||||
AssertObject: workspaceRBACObj,
|
||||
},
|
||||
"GET:/api/v2/workspaceresources/{workspaceresource}": {
|
||||
AssertAction: rbac.ActionRead,
|
||||
AssertObject: workspaceRBACObj,
|
||||
},
|
||||
"PATCH:/api/v2/workspacebuilds/{workspacebuild}/cancel": {
|
||||
AssertAction: rbac.ActionUpdate,
|
||||
AssertObject: workspaceRBACObj,
|
||||
@@ -254,14 +116,6 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
|
||||
AssertAction: rbac.ActionRead,
|
||||
AssertObject: workspaceRBACObj,
|
||||
},
|
||||
"GET:/api/v2/workspaceagents/{workspaceagent}/dial": {
|
||||
AssertAction: rbac.ActionCreate,
|
||||
AssertObject: workspaceExecObj,
|
||||
},
|
||||
"GET:/api/v2/workspaceagents/{workspaceagent}/turn": {
|
||||
AssertAction: rbac.ActionCreate,
|
||||
AssertObject: workspaceExecObj,
|
||||
},
|
||||
"GET:/api/v2/workspaceagents/{workspaceagent}/pty": {
|
||||
AssertAction: rbac.ActionCreate,
|
||||
AssertObject: workspaceExecObj,
|
||||
@@ -270,11 +124,6 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
|
||||
AssertAction: rbac.ActionCreate,
|
||||
AssertObject: workspaceExecObj,
|
||||
},
|
||||
"GET:/api/v2/workspaces/": {
|
||||
StatusCode: http.StatusOK,
|
||||
AssertAction: rbac.ActionRead,
|
||||
AssertObject: workspaceRBACObj,
|
||||
},
|
||||
"GET:/api/v2/organizations/{organization}/templates": {
|
||||
StatusCode: http.StatusOK,
|
||||
AssertAction: rbac.ActionRead,
|
||||
@@ -293,7 +142,7 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
|
||||
AssertObject: rbac.ResourceTemplate.InOrg(a.Template.OrganizationID),
|
||||
},
|
||||
"POST:/api/v2/files": {AssertAction: rbac.ActionCreate, AssertObject: rbac.ResourceFile},
|
||||
"GET:/api/v2/files/{hash}": {
|
||||
"GET:/api/v2/files/{fileID}": {
|
||||
AssertAction: rbac.ActionRead,
|
||||
AssertObject: rbac.ResourceFile.WithOwner(a.Admin.UserID.String()),
|
||||
},
|
||||
@@ -375,7 +224,7 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
|
||||
AssertAction: rbac.ActionRead,
|
||||
AssertObject: rbac.ResourceTemplate.InOrg(a.Template.OrganizationID),
|
||||
},
|
||||
"POST:/api/v2/organizations/{organization}/workspaces": {
|
||||
"POST:/api/v2/organizations/{organization}/members/{user}/workspaces": {
|
||||
AssertAction: rbac.ActionCreate,
|
||||
// No ID when creating
|
||||
AssertObject: workspaceRBACObj,
|
||||
@@ -384,13 +233,17 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
|
||||
AssertAction: rbac.ActionRead,
|
||||
AssertObject: workspaceRBACObj,
|
||||
},
|
||||
"GET:/api/v2/users": {StatusCode: http.StatusOK, AssertObject: rbac.ResourceUser},
|
||||
"GET:/api/v2/users": {StatusCode: http.StatusOK, AssertObject: rbac.ResourceUser},
|
||||
"GET:/api/v2/applications/auth-redirect": {AssertAction: rbac.ActionCreate, AssertObject: rbac.ResourceAPIKey},
|
||||
|
||||
// These endpoints need payloads to get to the auth part. Payloads will be required
|
||||
"PUT:/api/v2/users/{user}/roles": {StatusCode: http.StatusBadRequest, NoAuthorize: true},
|
||||
"PUT:/api/v2/organizations/{organization}/members/{user}/roles": {NoAuthorize: true},
|
||||
"POST:/api/v2/workspaces/{workspace}/builds": {StatusCode: http.StatusBadRequest, NoAuthorize: true},
|
||||
"POST:/api/v2/organizations/{organization}/templateversions": {StatusCode: http.StatusBadRequest, NoAuthorize: true},
|
||||
|
||||
// Endpoints that use the SQLQuery filter.
|
||||
"GET:/api/v2/workspaces/": {StatusCode: http.StatusOK, NoAuthorize: true},
|
||||
}
|
||||
|
||||
// Routes like proxy routes support all HTTP methods. A helper func to expand
|
||||
@@ -408,16 +261,144 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
|
||||
|
||||
assertAllHTTPMethods("/%40{user}/{workspace_and_agent}/apps/{workspaceapp}/*", RouteCheck{
|
||||
AssertAction: rbac.ActionCreate,
|
||||
AssertObject: workspaceExecObj,
|
||||
AssertObject: applicationConnectObj,
|
||||
})
|
||||
assertAllHTTPMethods("/@{user}/{workspace_and_agent}/apps/{workspaceapp}/*", RouteCheck{
|
||||
AssertAction: rbac.ActionCreate,
|
||||
AssertObject: workspaceExecObj,
|
||||
AssertObject: applicationConnectObj,
|
||||
})
|
||||
|
||||
return skipRoutes, assertRoute
|
||||
}
|
||||
|
||||
type RouteCheck struct {
|
||||
NoAuthorize bool
|
||||
AssertAction rbac.Action
|
||||
AssertObject rbac.Object
|
||||
StatusCode int
|
||||
}
|
||||
|
||||
type AuthTester struct {
|
||||
t *testing.T
|
||||
api *coderd.API
|
||||
authorizer *RecordingAuthorizer
|
||||
|
||||
Client *codersdk.Client
|
||||
Workspace codersdk.Workspace
|
||||
Organization codersdk.Organization
|
||||
Admin codersdk.CreateFirstUserResponse
|
||||
Template codersdk.Template
|
||||
Version codersdk.TemplateVersion
|
||||
WorkspaceResource codersdk.WorkspaceResource
|
||||
File codersdk.UploadResponse
|
||||
TemplateVersionDryRun codersdk.ProvisionerJob
|
||||
TemplateParam codersdk.Parameter
|
||||
URLParams map[string]string
|
||||
}
|
||||
|
||||
func NewAuthTester(ctx context.Context, t *testing.T, client *codersdk.Client, api *coderd.API, admin codersdk.CreateFirstUserResponse) *AuthTester {
|
||||
authorizer, ok := api.Authorizer.(*RecordingAuthorizer)
|
||||
if !ok {
|
||||
t.Fail()
|
||||
}
|
||||
// The provisioner will call to coderd and register itself. This is async,
|
||||
// so we wait for it to occur.
|
||||
require.Eventually(t, func() bool {
|
||||
provisionerds, err := client.ProvisionerDaemons(ctx)
|
||||
return assert.NoError(t, err) && len(provisionerds) > 0
|
||||
}, testutil.WaitLong, testutil.IntervalSlow)
|
||||
|
||||
provisionerds, err := client.ProvisionerDaemons(ctx)
|
||||
require.NoError(t, err, "fetch provisioners")
|
||||
require.Len(t, provisionerds, 1)
|
||||
|
||||
organization, err := client.Organization(ctx, admin.OrganizationID)
|
||||
require.NoError(t, err, "fetch org")
|
||||
|
||||
// Setup some data in the database.
|
||||
version := CreateTemplateVersion(t, client, admin.OrganizationID, &echo.Responses{
|
||||
Parse: echo.ParseComplete,
|
||||
Provision: []*proto.Provision_Response{{
|
||||
Type: &proto.Provision_Response_Complete{
|
||||
Complete: &proto.Provision_Complete{
|
||||
// Return a workspace resource
|
||||
Resources: []*proto.Resource{{
|
||||
Name: "some",
|
||||
Type: "example",
|
||||
Agents: []*proto.Agent{{
|
||||
Name: "agent",
|
||||
Id: "something",
|
||||
Auth: &proto.Agent_Token{},
|
||||
Apps: []*proto.App{{
|
||||
Name: "testapp",
|
||||
Url: "http://localhost:3000",
|
||||
}},
|
||||
}},
|
||||
}},
|
||||
},
|
||||
},
|
||||
}},
|
||||
})
|
||||
AwaitTemplateVersionJob(t, client, version.ID)
|
||||
template := CreateTemplate(t, client, admin.OrganizationID, version.ID)
|
||||
workspace := CreateWorkspace(t, client, admin.OrganizationID, template.ID)
|
||||
AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
|
||||
file, err := client.Upload(ctx, codersdk.ContentTypeTar, make([]byte, 1024))
|
||||
require.NoError(t, err, "upload file")
|
||||
workspace, err = client.Workspace(ctx, workspace.ID)
|
||||
require.NoError(t, err, "workspace resources")
|
||||
templateVersionDryRun, err := client.CreateTemplateVersionDryRun(ctx, version.ID, codersdk.CreateTemplateVersionDryRunRequest{
|
||||
ParameterValues: []codersdk.CreateParameterRequest{},
|
||||
})
|
||||
require.NoError(t, err, "template version dry-run")
|
||||
|
||||
templateParam, err := client.CreateParameter(ctx, codersdk.ParameterTemplate, template.ID, codersdk.CreateParameterRequest{
|
||||
Name: "test-param",
|
||||
SourceValue: "hello world",
|
||||
SourceScheme: codersdk.ParameterSourceSchemeData,
|
||||
DestinationScheme: codersdk.ParameterDestinationSchemeProvisionerVariable,
|
||||
})
|
||||
require.NoError(t, err, "create template param")
|
||||
urlParameters := map[string]string{
|
||||
"{organization}": admin.OrganizationID.String(),
|
||||
"{user}": admin.UserID.String(),
|
||||
"{organizationname}": organization.Name,
|
||||
"{workspace}": workspace.ID.String(),
|
||||
"{workspacebuild}": workspace.LatestBuild.ID.String(),
|
||||
"{workspacename}": workspace.Name,
|
||||
"{workspaceagent}": workspace.LatestBuild.Resources[0].Agents[0].ID.String(),
|
||||
"{buildnumber}": strconv.FormatInt(int64(workspace.LatestBuild.BuildNumber), 10),
|
||||
"{template}": template.ID.String(),
|
||||
"{fileID}": file.ID.String(),
|
||||
"{workspaceresource}": workspace.LatestBuild.Resources[0].ID.String(),
|
||||
"{workspaceapp}": workspace.LatestBuild.Resources[0].Agents[0].Apps[0].Name,
|
||||
"{templateversion}": version.ID.String(),
|
||||
"{jobID}": templateVersionDryRun.ID.String(),
|
||||
"{templatename}": template.Name,
|
||||
"{workspace_and_agent}": workspace.Name + "." + workspace.LatestBuild.Resources[0].Agents[0].Name,
|
||||
// Only checking template scoped params here
|
||||
"parameters/{scope}/{id}": fmt.Sprintf("parameters/%s/%s",
|
||||
string(templateParam.Scope), templateParam.ScopeID.String()),
|
||||
}
|
||||
|
||||
return &AuthTester{
|
||||
t: t,
|
||||
api: api,
|
||||
authorizer: authorizer,
|
||||
Client: client,
|
||||
Workspace: workspace,
|
||||
Organization: organization,
|
||||
Admin: admin,
|
||||
Template: template,
|
||||
Version: version,
|
||||
WorkspaceResource: workspace.LatestBuild.Resources[0],
|
||||
File: file,
|
||||
TemplateVersionDryRun: templateVersionDryRun,
|
||||
TemplateParam: templateParam,
|
||||
URLParams: urlParameters,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck, skipRoutes map[string]string) {
|
||||
// Always fail auth from this point forward
|
||||
a.authorizer.AlwaysReturn = rbac.ForbiddenWithInternal(xerrors.New("fake implementation"), nil, nil)
|
||||
@@ -443,7 +424,7 @@ func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck
|
||||
}
|
||||
|
||||
err := chi.Walk(
|
||||
a.api.Handler,
|
||||
a.api.RootHandler,
|
||||
func(
|
||||
method string,
|
||||
route string,
|
||||
@@ -518,45 +499,88 @@ func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck
|
||||
type authCall struct {
|
||||
SubjectID string
|
||||
Roles []string
|
||||
Groups []string
|
||||
Scope rbac.Scope
|
||||
Action rbac.Action
|
||||
Object rbac.Object
|
||||
}
|
||||
|
||||
type recordingAuthorizer struct {
|
||||
type RecordingAuthorizer struct {
|
||||
Called *authCall
|
||||
AlwaysReturn error
|
||||
}
|
||||
|
||||
func (r *recordingAuthorizer) ByRoleName(_ context.Context, subjectID string, roleNames []string, action rbac.Action, object rbac.Object) error {
|
||||
var _ rbac.Authorizer = (*RecordingAuthorizer)(nil)
|
||||
|
||||
// ByRoleNameSQL does not record the call. This matches the postgres behavior
|
||||
// of not calling Authorize()
|
||||
func (r *RecordingAuthorizer) ByRoleNameSQL(_ context.Context, _ string, _ []string, _ rbac.Scope, _ []string, _ rbac.Action, _ rbac.Object) error {
|
||||
return r.AlwaysReturn
|
||||
}
|
||||
|
||||
func (r *RecordingAuthorizer) ByRoleName(_ context.Context, subjectID string, roleNames []string, scope rbac.Scope, groups []string, action rbac.Action, object rbac.Object) error {
|
||||
r.Called = &authCall{
|
||||
SubjectID: subjectID,
|
||||
Roles: roleNames,
|
||||
Groups: groups,
|
||||
Scope: scope,
|
||||
Action: action,
|
||||
Object: object,
|
||||
}
|
||||
return r.AlwaysReturn
|
||||
}
|
||||
|
||||
func (r *recordingAuthorizer) PrepareByRoleName(_ context.Context, subjectID string, roles []string, action rbac.Action, _ string) (rbac.PreparedAuthorized, error) {
|
||||
func (r *RecordingAuthorizer) PrepareByRoleName(_ context.Context, subjectID string, roles []string, scope rbac.Scope, groups []string, action rbac.Action, _ string) (rbac.PreparedAuthorized, error) {
|
||||
return &fakePreparedAuthorizer{
|
||||
Original: r,
|
||||
SubjectID: subjectID,
|
||||
Roles: roles,
|
||||
Action: action,
|
||||
Original: r,
|
||||
SubjectID: subjectID,
|
||||
Roles: roles,
|
||||
Scope: scope,
|
||||
Action: action,
|
||||
HardCodedSQLString: "true",
|
||||
Groups: groups,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *recordingAuthorizer) reset() {
|
||||
func (r *RecordingAuthorizer) reset() {
|
||||
r.Called = nil
|
||||
}
|
||||
|
||||
type fakePreparedAuthorizer struct {
|
||||
Original *recordingAuthorizer
|
||||
SubjectID string
|
||||
Roles []string
|
||||
Action rbac.Action
|
||||
Original *RecordingAuthorizer
|
||||
SubjectID string
|
||||
Roles []string
|
||||
Scope rbac.Scope
|
||||
Action rbac.Action
|
||||
Groups []string
|
||||
HardCodedSQLString string
|
||||
HardCodedRegoString string
|
||||
}
|
||||
|
||||
func (f *fakePreparedAuthorizer) Authorize(ctx context.Context, object rbac.Object) error {
|
||||
return f.Original.ByRoleName(ctx, f.SubjectID, f.Roles, f.Action, object)
|
||||
return f.Original.ByRoleName(ctx, f.SubjectID, f.Roles, f.Scope, f.Groups, f.Action, object)
|
||||
}
|
||||
|
||||
// Compile returns a compiled version of the authorizer that will work for
|
||||
// in memory databases. This fake version will not work against a SQL database.
|
||||
func (f *fakePreparedAuthorizer) Compile() (rbac.AuthorizeFilter, error) {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func (f *fakePreparedAuthorizer) Eval(object rbac.Object) bool {
|
||||
return f.Original.ByRoleNameSQL(context.Background(), f.SubjectID, f.Roles, f.Scope, f.Groups, f.Action, object) == nil
|
||||
}
|
||||
|
||||
func (f fakePreparedAuthorizer) RegoString() string {
|
||||
if f.HardCodedRegoString != "" {
|
||||
return f.HardCodedRegoString
|
||||
}
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (f fakePreparedAuthorizer) SQLString(_ rbac.SQLConfig) string {
|
||||
if f.HardCodedSQLString != "" {
|
||||
return f.HardCodedSQLString
|
||||
}
|
||||
panic("not implemented")
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
package coderdtest_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
)
|
||||
|
||||
func TestAuthorizeAllEndpoints(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
|
||||
// Required for any subdomain-based proxy tests to pass.
|
||||
AppHostname: "*.test.coder.com",
|
||||
Authorizer: &coderdtest.RecordingAuthorizer{},
|
||||
IncludeProvisionerDaemon: true,
|
||||
})
|
||||
admin := coderdtest.CreateFirstUser(t, client)
|
||||
a := coderdtest.NewAuthTester(context.Background(), t, client, api, admin)
|
||||
skipRoute, assertRoute := coderdtest.AGPLRoutes(a)
|
||||
a.Test(context.Background(), assertRoute, skipRoute)
|
||||
}
|
||||
+142
-99
@@ -7,9 +7,9 @@ import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
@@ -21,9 +21,10 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -38,20 +39,24 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
"google.golang.org/api/idtoken"
|
||||
"google.golang.org/api/option"
|
||||
"tailscale.com/derp"
|
||||
"tailscale.com/net/stun/stuntest"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/nettype"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/coderd"
|
||||
"github.com/coder/coder/coderd/audit"
|
||||
"github.com/coder/coder/coderd/autobuild/executor"
|
||||
"github.com/coder/coder/coderd/awsidentity"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/databasefake"
|
||||
"github.com/coder/coder/coderd/database/postgres"
|
||||
"github.com/coder/coder/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/coderd/gitsshkey"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
"github.com/coder/coder/coderd/telemetry"
|
||||
"github.com/coder/coder/coderd/turnconn"
|
||||
"github.com/coder/coder/coderd/util/ptr"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/cryptorand"
|
||||
@@ -59,12 +64,15 @@ import (
|
||||
"github.com/coder/coder/provisionerd"
|
||||
"github.com/coder/coder/provisionersdk"
|
||||
"github.com/coder/coder/provisionersdk/proto"
|
||||
"github.com/coder/coder/tailnet"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
AppHostname string
|
||||
AWSCertificates awsidentity.Certificates
|
||||
Authorizer rbac.Authorizer
|
||||
Experimental bool
|
||||
AzureCertificates x509.VerifyOptions
|
||||
GithubOAuth2Config *coderd.GithubOAuth2Config
|
||||
OIDCConfig *coderd.OIDCConfig
|
||||
@@ -74,12 +82,20 @@ type Options struct {
|
||||
AutoImportTemplates []coderd.AutoImportTemplate
|
||||
AutobuildTicker <-chan time.Time
|
||||
AutobuildStats chan<- executor.Stats
|
||||
Auditor audit.Auditor
|
||||
TLSCertificates []tls.Certificate
|
||||
|
||||
// IncludeProvisionerDaemon when true means to start an in-memory provisionerD
|
||||
IncludeProvisionerDaemon bool
|
||||
APIBuilder func(*coderd.Options) *coderd.API
|
||||
MetricsCacheRefreshInterval time.Duration
|
||||
AgentStatsRefreshInterval time.Duration
|
||||
DeploymentFlags *codersdk.DeploymentFlags
|
||||
|
||||
// Overriding the database is heavily discouraged.
|
||||
// It should only be used in cases where multiple Coder
|
||||
// test instances are running against the same database.
|
||||
Database database.Store
|
||||
Pubsub database.Pubsub
|
||||
}
|
||||
|
||||
// New constructs a codersdk client connected to an in-memory API instance.
|
||||
@@ -109,14 +125,11 @@ func NewWithProvisionerCloser(t *testing.T, options *Options) (*codersdk.Client,
|
||||
// and is a temporary measure while the API to register provisioners is ironed
|
||||
// out.
|
||||
func newWithCloser(t *testing.T, options *Options) (*codersdk.Client, io.Closer) {
|
||||
client, closer, _ := newWithAPI(t, options)
|
||||
client, closer, _ := NewWithAPI(t, options)
|
||||
return client, closer
|
||||
}
|
||||
|
||||
// newWithAPI constructs an in-memory API instance and returns a client to talk to it.
|
||||
// Most tests never need a reference to the API, but AuthorizationTest in this module uses it.
|
||||
// Do not expose the API or wrath shall descend upon thee.
|
||||
func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *coderd.API) {
|
||||
func NewOptions(t *testing.T, options *Options) (func(http.Handler), context.CancelFunc, *coderd.Options) {
|
||||
if options == nil {
|
||||
options = &Options{}
|
||||
}
|
||||
@@ -137,124 +150,147 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
|
||||
close(options.AutobuildStats)
|
||||
})
|
||||
}
|
||||
if options.APIBuilder == nil {
|
||||
options.APIBuilder = coderd.New
|
||||
}
|
||||
|
||||
// This can be hotswapped for a live database instance.
|
||||
db := databasefake.New()
|
||||
pubsub := database.NewPubsubInMemory()
|
||||
if os.Getenv("DB") != "" {
|
||||
connectionURL, closePg, err := postgres.Open()
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(closePg)
|
||||
sqlDB, err := sql.Open("postgres", connectionURL)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = sqlDB.Close()
|
||||
})
|
||||
db = database.New(sqlDB)
|
||||
|
||||
pubsub, err = database.NewPubsub(context.Background(), sqlDB, connectionURL)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = pubsub.Close()
|
||||
})
|
||||
if options.Database == nil {
|
||||
options.Database, options.Pubsub = dbtestutil.NewDB(t)
|
||||
}
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer t.Cleanup(cancelFunc) // Defer to ensure cancelFunc is executed first.
|
||||
|
||||
lifecycleExecutor := executor.New(
|
||||
ctx,
|
||||
db,
|
||||
options.Database,
|
||||
slogtest.Make(t, nil).Named("autobuild.executor").Leveled(slog.LevelDebug),
|
||||
options.AutobuildTicker,
|
||||
).WithStatsChannel(options.AutobuildStats)
|
||||
lifecycleExecutor.Run()
|
||||
|
||||
srv := httptest.NewUnstartedServer(nil)
|
||||
var mutex sync.RWMutex
|
||||
var handler http.Handler
|
||||
srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mutex.RLock()
|
||||
defer mutex.RUnlock()
|
||||
if handler != nil {
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
}))
|
||||
srv.Config.BaseContext = func(_ net.Listener) context.Context {
|
||||
return ctx
|
||||
}
|
||||
srv.Start()
|
||||
if options.TLSCertificates != nil {
|
||||
srv.TLS = &tls.Config{
|
||||
Certificates: options.TLSCertificates,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
srv.StartTLS()
|
||||
} else {
|
||||
srv.Start()
|
||||
}
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
tcpAddr, ok := srv.Listener.Addr().(*net.TCPAddr)
|
||||
require.True(t, ok)
|
||||
|
||||
serverURL, err := url.Parse(srv.URL)
|
||||
require.NoError(t, err)
|
||||
serverURL.Host = fmt.Sprintf("localhost:%d", tcpAddr.Port)
|
||||
|
||||
derpPort, err := strconv.Atoi(serverURL.Port())
|
||||
require.NoError(t, err)
|
||||
|
||||
stunAddr, stunCleanup := stuntest.ServeWithPacketListener(t, nettype.Std{})
|
||||
t.Cleanup(stunCleanup)
|
||||
|
||||
derpServer := derp.NewServer(key.NewNode(), tailnet.Logger(slogtest.Make(t, nil).Named("derp")))
|
||||
derpServer.SetMeshKey("test-key")
|
||||
|
||||
// match default with cli default
|
||||
if options.SSHKeygenAlgorithm == "" {
|
||||
options.SSHKeygenAlgorithm = gitsshkey.AlgorithmEd25519
|
||||
}
|
||||
|
||||
turnServer, err := turnconn.New(nil)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = turnServer.Close()
|
||||
})
|
||||
var appHostnameRegex *regexp.Regexp
|
||||
if options.AppHostname != "" {
|
||||
var err error
|
||||
appHostnameRegex, err = httpapi.CompileHostnamePattern(options.AppHostname)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// We set the handler after server creation for the access URL.
|
||||
coderAPI := options.APIBuilder(&coderd.Options{
|
||||
AgentConnectionUpdateFrequency: 150 * time.Millisecond,
|
||||
// Force a long disconnection timeout to ensure
|
||||
// agents are not marked as disconnected during slow tests.
|
||||
AgentInactiveDisconnectTimeout: testutil.WaitShort,
|
||||
AccessURL: serverURL,
|
||||
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
|
||||
CacheDir: t.TempDir(),
|
||||
Database: db,
|
||||
Pubsub: pubsub,
|
||||
return func(h http.Handler) {
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
handler = h
|
||||
}, cancelFunc, &coderd.Options{
|
||||
AgentConnectionUpdateFrequency: 150 * time.Millisecond,
|
||||
// Force a long disconnection timeout to ensure
|
||||
// agents are not marked as disconnected during slow tests.
|
||||
AgentInactiveDisconnectTimeout: testutil.WaitShort,
|
||||
AccessURL: serverURL,
|
||||
AppHostname: options.AppHostname,
|
||||
AppHostnameRegex: appHostnameRegex,
|
||||
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
|
||||
CacheDir: t.TempDir(),
|
||||
Database: options.Database,
|
||||
Pubsub: options.Pubsub,
|
||||
Experimental: options.Experimental,
|
||||
|
||||
AWSCertificates: options.AWSCertificates,
|
||||
AzureCertificates: options.AzureCertificates,
|
||||
GithubOAuth2Config: options.GithubOAuth2Config,
|
||||
OIDCConfig: options.OIDCConfig,
|
||||
GoogleTokenValidator: options.GoogleTokenValidator,
|
||||
SSHKeygenAlgorithm: options.SSHKeygenAlgorithm,
|
||||
TURNServer: turnServer,
|
||||
APIRateLimit: options.APIRateLimit,
|
||||
Authorizer: options.Authorizer,
|
||||
Telemetry: telemetry.NewNoop(),
|
||||
DERPMap: &tailcfg.DERPMap{
|
||||
Regions: map[int]*tailcfg.DERPRegion{
|
||||
1: {
|
||||
RegionID: 1,
|
||||
RegionCode: "coder",
|
||||
RegionName: "Coder",
|
||||
Nodes: []*tailcfg.DERPNode{{
|
||||
Name: "1a",
|
||||
RegionID: 1,
|
||||
IPv4: "127.0.0.1",
|
||||
DERPPort: derpPort,
|
||||
STUNPort: -1,
|
||||
InsecureForTests: true,
|
||||
ForceHTTP: true,
|
||||
}},
|
||||
Auditor: options.Auditor,
|
||||
AWSCertificates: options.AWSCertificates,
|
||||
AzureCertificates: options.AzureCertificates,
|
||||
GithubOAuth2Config: options.GithubOAuth2Config,
|
||||
OIDCConfig: options.OIDCConfig,
|
||||
GoogleTokenValidator: options.GoogleTokenValidator,
|
||||
SSHKeygenAlgorithm: options.SSHKeygenAlgorithm,
|
||||
DERPServer: derpServer,
|
||||
APIRateLimit: options.APIRateLimit,
|
||||
Authorizer: options.Authorizer,
|
||||
Telemetry: telemetry.NewNoop(),
|
||||
TLSCertificates: options.TLSCertificates,
|
||||
DERPMap: &tailcfg.DERPMap{
|
||||
Regions: map[int]*tailcfg.DERPRegion{
|
||||
1: {
|
||||
EmbeddedRelay: true,
|
||||
RegionID: 1,
|
||||
RegionCode: "coder",
|
||||
RegionName: "Coder",
|
||||
Nodes: []*tailcfg.DERPNode{{
|
||||
Name: "1a",
|
||||
RegionID: 1,
|
||||
IPv4: "127.0.0.1",
|
||||
DERPPort: derpPort,
|
||||
STUNPort: stunAddr.Port,
|
||||
InsecureForTests: true,
|
||||
ForceHTTP: options.TLSCertificates == nil,
|
||||
}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
AutoImportTemplates: options.AutoImportTemplates,
|
||||
MetricsCacheRefreshInterval: options.MetricsCacheRefreshInterval,
|
||||
AgentStatsRefreshInterval: options.AgentStatsRefreshInterval,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
_ = coderAPI.Close()
|
||||
})
|
||||
srv.Config.Handler = coderAPI.Handler
|
||||
AutoImportTemplates: options.AutoImportTemplates,
|
||||
MetricsCacheRefreshInterval: options.MetricsCacheRefreshInterval,
|
||||
AgentStatsRefreshInterval: options.AgentStatsRefreshInterval,
|
||||
DeploymentFlags: options.DeploymentFlags,
|
||||
}
|
||||
}
|
||||
|
||||
// NewWithAPI constructs an in-memory API instance and returns a client to talk to it.
|
||||
// Most tests never need a reference to the API, but AuthorizationTest in this module uses it.
|
||||
// Do not expose the API or wrath shall descend upon thee.
|
||||
func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *coderd.API) {
|
||||
if options == nil {
|
||||
options = &Options{}
|
||||
}
|
||||
setHandler, cancelFunc, newOptions := NewOptions(t, options)
|
||||
// We set the handler after server creation for the access URL.
|
||||
coderAPI := coderd.New(newOptions)
|
||||
setHandler(coderAPI.RootHandler)
|
||||
var provisionerCloser io.Closer = nopcloser{}
|
||||
if options.IncludeProvisionerDaemon {
|
||||
provisionerCloser = NewProvisionerDaemon(t, coderAPI)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
cancelFunc()
|
||||
_ = provisionerCloser.Close()
|
||||
_ = coderAPI.Close()
|
||||
})
|
||||
|
||||
return codersdk.New(serverURL), provisionerCloser, coderAPI
|
||||
return codersdk.New(coderAPI.AccessURL), provisionerCloser, coderAPI
|
||||
}
|
||||
|
||||
// NewProvisionerDaemon launches a provisionerd instance configured to work
|
||||
@@ -391,12 +427,13 @@ func createAnotherUserRetry(t *testing.T, client *codersdk.Client, organizationI
|
||||
// with the responses provided. It uses the "echo" provisioner for compatibility
|
||||
// with testing.
|
||||
func CreateTemplateVersion(t *testing.T, client *codersdk.Client, organizationID uuid.UUID, res *echo.Responses) codersdk.TemplateVersion {
|
||||
t.Helper()
|
||||
data, err := echo.Tar(res)
|
||||
require.NoError(t, err)
|
||||
file, err := client.Upload(context.Background(), codersdk.ContentTypeTar, data)
|
||||
require.NoError(t, err)
|
||||
templateVersion, err := client.CreateTemplateVersion(context.Background(), organizationID, codersdk.CreateTemplateVersionRequest{
|
||||
StorageSource: file.Hash,
|
||||
FileID: file.ID,
|
||||
StorageMethod: codersdk.ProvisionerStorageMethodFile,
|
||||
Provisioner: codersdk.ProvisionerTypeEcho,
|
||||
})
|
||||
@@ -444,7 +481,7 @@ func UpdateTemplateVersion(t *testing.T, client *codersdk.Client, organizationID
|
||||
require.NoError(t, err)
|
||||
templateVersion, err := client.CreateTemplateVersion(context.Background(), organizationID, codersdk.CreateTemplateVersionRequest{
|
||||
TemplateID: templateID,
|
||||
StorageSource: file.Hash,
|
||||
FileID: file.ID,
|
||||
StorageMethod: codersdk.ProvisionerStorageMethodFile,
|
||||
Provisioner: codersdk.ProvisionerTypeEcho,
|
||||
})
|
||||
@@ -462,7 +499,7 @@ func AwaitTemplateVersionJob(t *testing.T, client *codersdk.Client, version uuid
|
||||
var err error
|
||||
templateVersion, err = client.TemplateVersion(context.Background(), version)
|
||||
return assert.NoError(t, err) && templateVersion.Job.CompletedAt != nil
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
return templateVersion
|
||||
}
|
||||
|
||||
@@ -480,18 +517,22 @@ func AwaitWorkspaceBuildJob(t *testing.T, client *codersdk.Client, build uuid.UU
|
||||
}
|
||||
|
||||
// AwaitWorkspaceAgents waits for all resources with agents to be connected.
|
||||
func AwaitWorkspaceAgents(t *testing.T, client *codersdk.Client, build uuid.UUID) []codersdk.WorkspaceResource {
|
||||
func AwaitWorkspaceAgents(t *testing.T, client *codersdk.Client, workspaceID uuid.UUID) []codersdk.WorkspaceResource {
|
||||
t.Helper()
|
||||
|
||||
t.Logf("waiting for workspace agents (build %s)", build)
|
||||
t.Logf("waiting for workspace agents (workspace %s)", workspaceID)
|
||||
var resources []codersdk.WorkspaceResource
|
||||
require.Eventually(t, func() bool {
|
||||
var err error
|
||||
resources, err = client.WorkspaceResourcesByBuild(context.Background(), build)
|
||||
workspace, err := client.Workspace(context.Background(), workspaceID)
|
||||
if !assert.NoError(t, err) {
|
||||
return false
|
||||
}
|
||||
for _, resource := range resources {
|
||||
if workspace.LatestBuild.Job.CompletedAt.IsZero() {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, resource := range workspace.LatestBuild.Resources {
|
||||
for _, agent := range resource.Agents {
|
||||
if agent.Status != codersdk.WorkspaceAgentConnected {
|
||||
t.Logf("agent %s not connected yet", agent.Name)
|
||||
@@ -499,6 +540,8 @@ func AwaitWorkspaceAgents(t *testing.T, client *codersdk.Client, build uuid.UUID
|
||||
}
|
||||
}
|
||||
}
|
||||
resources = workspace.LatestBuild.Resources
|
||||
|
||||
return true
|
||||
}, testutil.WaitLong, testutil.IntervalFast)
|
||||
return resources
|
||||
@@ -518,7 +561,7 @@ func CreateWorkspace(t *testing.T, client *codersdk.Client, organization uuid.UU
|
||||
for _, mutator := range mutators {
|
||||
mutator(&req)
|
||||
}
|
||||
workspace, err := client.CreateWorkspace(context.Background(), organization, req)
|
||||
workspace, err := client.CreateWorkspace(context.Background(), organization, codersdk.Me, req)
|
||||
require.NoError(t, err)
|
||||
return workspace
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ func TestNew(t *testing.T) {
|
||||
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
|
||||
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
|
||||
coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
|
||||
coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
|
||||
_, _ = coderdtest.NewGoogleInstanceIdentity(t, "example", false)
|
||||
_, _ = coderdtest.NewAWSInstanceIdentity(t, "an-instance")
|
||||
}
|
||||
|
||||
+3
-3
@@ -23,7 +23,7 @@ func (api *API) logReportCSPViolations(rw http.ResponseWriter, r *http.Request)
|
||||
err := dec.Decode(&v)
|
||||
if err != nil {
|
||||
api.Logger.Warn(ctx, "csp violation", slog.Error(err))
|
||||
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Failed to read body, invalid json.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
@@ -34,7 +34,7 @@ func (api *API) logReportCSPViolations(rw http.ResponseWriter, r *http.Request)
|
||||
for k, v := range v.Report {
|
||||
fields = append(fields, slog.F(k, v))
|
||||
}
|
||||
api.Logger.Warn(ctx, "csp violation", fields...)
|
||||
api.Logger.Debug(ctx, "csp violation", fields...)
|
||||
|
||||
httpapi.Write(rw, http.StatusOK, "ok")
|
||||
httpapi.Write(ctx, rw, http.StatusOK, "ok")
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
+27
-5
@@ -12,7 +12,9 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
@@ -20,7 +22,10 @@ import (
|
||||
// It extends the generated interface to add transaction support.
|
||||
type Store interface {
|
||||
querier
|
||||
// customQuerier contains custom queries that are not generated.
|
||||
customQuerier
|
||||
|
||||
Ping(ctx context.Context) (time.Duration, error)
|
||||
InTx(func(Store) error) error
|
||||
}
|
||||
|
||||
@@ -30,24 +35,41 @@ type DBTX interface {
|
||||
PrepareContext(context.Context, string) (*sql.Stmt, error)
|
||||
QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
|
||||
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
|
||||
SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
|
||||
GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
|
||||
}
|
||||
|
||||
// New creates a new database store using a SQL database connection.
|
||||
func New(sdb *sql.DB) Store {
|
||||
dbx := sqlx.NewDb(sdb, "postgres")
|
||||
return &sqlQuerier{
|
||||
db: sdb,
|
||||
sdb: sdb,
|
||||
db: dbx,
|
||||
sdb: dbx,
|
||||
}
|
||||
}
|
||||
|
||||
// queries encompasses both are sqlc generated
|
||||
// queries and our custom queries.
|
||||
type querier interface {
|
||||
sqlcQuerier
|
||||
customQuerier
|
||||
}
|
||||
|
||||
type sqlQuerier struct {
|
||||
sdb *sql.DB
|
||||
sdb *sqlx.DB
|
||||
db DBTX
|
||||
}
|
||||
|
||||
// Ping returns the time it takes to ping the database.
|
||||
func (q *sqlQuerier) Ping(ctx context.Context) (time.Duration, error) {
|
||||
start := time.Now()
|
||||
err := q.sdb.PingContext(ctx)
|
||||
return time.Since(start), err
|
||||
}
|
||||
|
||||
// InTx performs database operations inside a transaction.
|
||||
func (q *sqlQuerier) InTx(function func(Store) error) error {
|
||||
if _, ok := q.db.(*sql.Tx); ok {
|
||||
if _, ok := q.db.(*sqlx.Tx); ok {
|
||||
// If the current inner "db" is already a transaction, we just reuse it.
|
||||
// We do not need to handle commit/rollback as the outer tx will handle
|
||||
// that.
|
||||
@@ -58,7 +80,7 @@ func (q *sqlQuerier) InTx(function func(Store) error) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
transaction, err := q.sdb.Begin()
|
||||
transaction, err := q.sdb.BeginTxx(context.Background(), nil)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("begin transaction: %w", err)
|
||||
}
|
||||
|
||||
@@ -4,12 +4,15 @@ package database_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/migrations"
|
||||
"github.com/coder/coder/coderd/database/postgres"
|
||||
)
|
||||
|
||||
func TestNestedInTx(t *testing.T) {
|
||||
@@ -20,7 +23,7 @@ func TestNestedInTx(t *testing.T) {
|
||||
|
||||
uid := uuid.New()
|
||||
sqlDB := testSQLDB(t)
|
||||
err := database.MigrateUp(sqlDB)
|
||||
err := migrations.Up(sqlDB)
|
||||
require.NoError(t, err, "migrations")
|
||||
|
||||
db := database.New(sqlDB)
|
||||
@@ -48,3 +51,17 @@ func TestNestedInTx(t *testing.T) {
|
||||
require.NoError(t, err, "user exists")
|
||||
require.Equal(t, uid, user.ID, "user id expected")
|
||||
}
|
||||
|
||||
func testSQLDB(t testing.TB) *sql.DB {
|
||||
t.Helper()
|
||||
|
||||
connection, closeFn, err := postgres.Open()
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(closeFn)
|
||||
|
||||
db, err := sql.Open("postgres", connection)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user