Skip to content

Commit

Permalink
Add send-env and set-env options to ssh command
Browse files Browse the repository at this point in the history
  • Loading branch information
aacebedo authored and pascalbreuninger committed Oct 10, 2024
1 parent 0166f3d commit 4e94be7
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 15 deletions.
43 changes: 38 additions & 5 deletions cmd/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ type SSHCmd struct {
ForwardPortsTimeout string
ForwardPorts []string
ReverseForwardPorts []string
SendEnvVars []string
SetEnvVars []string

Stdio bool
JumpContainer bool
Expand Down Expand Up @@ -90,6 +92,8 @@ func NewSSHCmd(f *flags.GlobalFlags) *cobra.Command {
dpFlags.SetGitCredentialsFlags(sshCmd.Flags(), &cmd.GitCredentialsFlags)
sshCmd.Flags().StringArrayVarP(&cmd.ForwardPorts, "forward-ports", "L", []string{}, "Specifies that connections to the given TCP port or Unix socket on the local (client) host are to be forwarded to the given host and port, or Unix socket, on the remote side.")
sshCmd.Flags().StringArrayVarP(&cmd.ReverseForwardPorts, "reverse-forward-ports", "R", []string{}, "Specifies that connections to the given TCP port or Unix socket on the local (client) host are to be reverse forwarded to the given host and port, or Unix socket, on the remote side.")
sshCmd.Flags().StringArrayVarP(&cmd.SendEnvVars, "send-env", "", []string{}, "Specifies which local env variables shall be sent to the container.")
sshCmd.Flags().StringArrayVarP(&cmd.SetEnvVars, "set-env", "", []string{}, "Specifies env variables to be set in the container.")
sshCmd.Flags().StringVar(&cmd.ForwardPortsTimeout, "forward-ports-timeout", "", "Specifies the timeout after which the command should terminate when the ports are unused.")
sshCmd.Flags().StringVar(&cmd.Command, "command", "", "The command to execute within the workspace")
sshCmd.Flags().StringVar(&cmd.User, "user", "", "The user of the workspace to use")
Expand Down Expand Up @@ -215,6 +219,25 @@ func startWait(
}
}

func (cmd *SSHCmd) retrieveEnVars() (map[string]string, error) {
envVars := make(map[string]string)
for _, envVar := range cmd.SendEnvVars {
envVarValue, exist := os.LookupEnv(envVar)
if exist {
envVars[envVar] = envVarValue
}
}
for _, envVar := range cmd.SetEnvVars {
parts := strings.Split(envVar, "=")
if len(parts) != 2 {
return nil, fmt.Errorf("invalid env var: %s", envVar)
}
envVars[parts[0]] = parts[1]
}

return envVars, nil
}

func (cmd *SSHCmd) jumpContainer(
ctx context.Context,
devPodConfig *config.Config,
Expand All @@ -235,6 +258,11 @@ func (cmd *SSHCmd) jumpContainer(
return err
}

envVars, err := cmd.retrieveEnVars()
if err != nil {
return err
}

// tunnel to container
return tunnel.NewContainerTunnel(client, cmd.Proxy, log).
Run(ctx, func(ctx context.Context, containerClient *ssh.Client) error {
Expand All @@ -243,7 +271,7 @@ func (cmd *SSHCmd) jumpContainer(

// start ssh tunnel
return cmd.startTunnel(ctx, devPodConfig, containerClient, client.Workspace(), log)
}, devPodConfig)
}, devPodConfig, envVars)
}

func (cmd *SSHCmd) forwardTimeout(log log.Logger) (time.Duration, error) {
Expand Down Expand Up @@ -398,6 +426,11 @@ func (cmd *SSHCmd) startTunnel(ctx context.Context, devPodConfig *config.Config,
command = fmt.Sprintf("su -c \"%s\" '%s'", command, cmd.User)
}

envVars, err := cmd.retrieveEnVars()
if err != nil {
return err
}

// Traffic is coming in from the outside, we need to forward it to the container
if cmd.Proxy || cmd.Stdio {
if cmd.Proxy {
Expand All @@ -408,7 +441,7 @@ func (cmd *SSHCmd) startTunnel(ctx context.Context, devPodConfig *config.Config,
}()
}

return devssh.Run(ctx, containerClient, command, os.Stdin, os.Stdout, writer)
return devssh.Run(ctx, containerClient, command, os.Stdin, os.Stdout, writer, envVars)
}

return machine.StartSSHSession(
Expand All @@ -418,7 +451,7 @@ func (cmd *SSHCmd) startTunnel(ctx context.Context, devPodConfig *config.Config,
!cmd.Proxy && cmd.AgentForwarding &&
devPodConfig.ContextOption(config.ContextOptionSSHAgentForwarding) == "true",
func(ctx context.Context, stdin io.Reader, stdout io.Writer, stderr io.Writer) error {
return devssh.Run(ctx, containerClient, command, stdin, stdout, stderr)
return devssh.Run(ctx, containerClient, command, stdin, stdout, stderr, envVars)
},
writer,
)
Expand Down Expand Up @@ -581,7 +614,7 @@ func (cmd *SSHCmd) setupGPGAgent(
command = fmt.Sprintf("su -c \"%s\" '%s'", command, cmd.User)
}

return devssh.Run(ctx, containerClient, command, nil, writer, writer)
return devssh.Run(ctx, containerClient, command, nil, writer, writer, nil)
}

func mergeDevPodSshOptions(cmd *SSHCmd) error {
Expand Down Expand Up @@ -616,7 +649,7 @@ func startWorkspaceCredentialServer(ctx context.Context, client *ssh.Client, use
args = append(args, "--runner")
command = fmt.Sprintf("%s %s", command, strings.Join(args, " "))

if err := devssh.Run(ctx, client, command, stdin, stdout, writer); err != nil {
if err := devssh.Run(ctx, client, command, stdin, stdout, writer, nil); err != nil {
return fmt.Errorf("run credentials server: %w", err)
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/devcontainer/sshtunnel/sshtunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func ExecuteCommand(
writer := log.Writer(logrus.InfoLevel, false)
defer writer.Close()

err = devssh.Run(ctx, sshClient, command, gRPCConnStdinReader, gRPCConnStdoutWriter, writer)
err = devssh.Run(ctx, sshClient, command, gRPCConnStdinReader, gRPCConnStdoutWriter, writer, nil)
if err != nil {
errChan <- errors.Wrap(err, "run agent command")
} else {
Expand Down
2 changes: 1 addition & 1 deletion pkg/gpg/gpg_forwarding.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func IsGpgTunnelRunning(

// capture the output, if it's empty it means we don't have gpg-forwarding
var out bytes.Buffer
err := devssh.Run(ctx, client, command, nil, &out, writer)
err := devssh.Run(ctx, client, command, nil, &out, writer, nil)

return err == nil && len(out.Bytes()) > 1
}
Expand Down
9 changes: 8 additions & 1 deletion pkg/ssh/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,20 @@ func ConfigFromKeyBytes(keyBytes []byte) (*ssh.ClientConfig, error) {
return clientConfig, nil
}

func Run(ctx context.Context, client *ssh.Client, command string, stdin io.Reader, stdout io.Writer, stderr io.Writer) error {
func Run(ctx context.Context, client *ssh.Client, command string, stdin io.Reader, stdout io.Writer, stderr io.Writer, envVars map[string]string) error {
sess, err := client.NewSession()
if err != nil {
return err
}
defer sess.Close()

for k, v := range envVars {
err = sess.Setenv(k, v)
if err != nil {
return err
}
}

exit := make(chan struct{})
defer close(exit)
go func() {
Expand Down
10 changes: 5 additions & 5 deletions pkg/tunnel/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type ContainerHandler struct {

type Handler func(ctx context.Context, containerClient *ssh.Client) error

func (c *ContainerHandler) Run(ctx context.Context, handler Handler, cfg *config.Config) error {
func (c *ContainerHandler) Run(ctx context.Context, handler Handler, cfg *config.Config, envVars map[string]string) error {
if handler == nil {
return nil
}
Expand Down Expand Up @@ -118,7 +118,7 @@ func (c *ContainerHandler) Run(ctx context.Context, handler Handler, cfg *config
}

// wait until we are done
if err := c.runRunInContainer(cancelCtx, sshClient, handler); err != nil {
if err := c.runRunInContainer(cancelCtx, sshClient, handler, envVars); err != nil {
containerChan <- fmt.Errorf("run in container: %w", err)
} else {
containerChan <- nil
Expand Down Expand Up @@ -164,7 +164,7 @@ func (c *ContainerHandler) updateConfig(ctx context.Context, sshClient *ssh.Clie
}

c.log.Debugf("Run command in container: %s", command)
err = devssh.Run(ctx, sshClient, command, nil, buf, buf)
err = devssh.Run(ctx, sshClient, command, nil, buf, buf, nil)
if err != nil {
c.log.Errorf("Error updating remote workspace: %s%v", buf.String(), err)
} else {
Expand All @@ -174,7 +174,7 @@ func (c *ContainerHandler) updateConfig(ctx context.Context, sshClient *ssh.Clie
}
}

func (c *ContainerHandler) runRunInContainer(ctx context.Context, sshClient *ssh.Client, runInContainer Handler) error {
func (c *ContainerHandler) runRunInContainer(ctx context.Context, sshClient *ssh.Client, runInContainer Handler, envVars map[string]string) error {
// compress info
workspaceInfo, _, err := c.client.AgentInfo(provider.CLIOptions{Proxy: c.proxy})
if err != nil {
Expand Down Expand Up @@ -211,7 +211,7 @@ func (c *ContainerHandler) runRunInContainer(ctx context.Context, sshClient *ssh
if c.log.GetLevel() == logrus.DebugLevel {
command += " --debug"
}
err = devssh.Run(cancelCtx, sshClient, command, stdinReader, stdoutWriter, writer)
err = devssh.Run(cancelCtx, sshClient, command, stdinReader, stdoutWriter, writer, envVars)
if err != nil {
c.log.Errorf("Error tunneling to container: %v", err)
return
Expand Down
4 changes: 2 additions & 2 deletions pkg/tunnel/services.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func RunInContainer(
command += " --debug"
}

err = devssh.Run(cancelCtx, containerClient, command, stdinReader, stdoutWriter, writer)
err = devssh.Run(cancelCtx, containerClient, command, stdinReader, stdoutWriter, writer, nil)
if err != nil {
return err
}
Expand All @@ -147,7 +147,7 @@ func RunInContainer(
func forwardDevContainerPorts(ctx context.Context, containerClient *ssh.Client, extraPorts []string, exitAfterTimeout time.Duration, log log.Logger) ([]string, error) {
stdout := &bytes.Buffer{}
stderr := &bytes.Buffer{}
err := devssh.Run(ctx, containerClient, "cat "+setup.ResultLocation, nil, stdout, stderr)
err := devssh.Run(ctx, containerClient, "cat "+setup.ResultLocation, nil, stdout, stderr, nil)
if err != nil {
return nil, fmt.Errorf("retrieve container result: %s\n%s%w", stdout.String(), stderr.String(), err)
}
Expand Down

0 comments on commit 4e94be7

Please sign in to comment.