diff --git a/cmd/ssh.go b/cmd/ssh.go index e5cb5a628..bd0959f62 100644 --- a/cmd/ssh.go +++ b/cmd/ssh.go @@ -38,6 +38,8 @@ type SSHCmd struct { ForwardPortsTimeout string ForwardPorts []string ReverseForwardPorts []string + SendEnvVars []string + SetEnvVars []string Stdio bool JumpContainer bool @@ -87,6 +89,8 @@ func NewSSHCmd(flags *flags.GlobalFlags) *cobra.Command { sshCmd.Flags().StringArrayVarP(&cmd.ForwardPorts, "forward-ports", "L", []string{}, "Specifies that connections to the given TCP port or Unix socket on the local (client) host are to be forwarded to the given host and port, or Unix socket, on the remote side.") sshCmd.Flags().StringArrayVarP(&cmd.ReverseForwardPorts, "reverse-forward-ports", "R", []string{}, "Specifies that connections to the given TCP port or Unix socket on the local (client) host are to be reverse forwarded to the given host and port, or Unix socket, on the remote side.") + sshCmd.Flags().StringArrayVarP(&cmd.SendEnvVars, "send-env", "", []string{}, "Specifies which local env variables shall be sent to the container.") + sshCmd.Flags().StringArrayVarP(&cmd.SetEnvVars, "set-env", "", []string{}, "Specifies env variables to be set in the container.") sshCmd.Flags().StringVar(&cmd.ForwardPortsTimeout, "forward-ports-timeout", "", "Specifies the timeout after which the command should terminate when the ports are unused.") sshCmd.Flags().StringVar(&cmd.Command, "command", "", "The command to execute within the workspace") sshCmd.Flags().StringVar(&cmd.User, "user", "", "The user of the workspace to use") @@ -211,6 +215,25 @@ func startWait( } } +func (cmd *SSHCmd) retrieveEnVars() (map[string]string, error) { + envVars := make(map[string]string) + for _, envVar := range cmd.SendEnvVars { + envVarValue, exist := os.LookupEnv(envVar) + if exist { + envVars[envVar] = envVarValue + } + } + for _, envVar := range cmd.SetEnvVars { + parts := strings.Split(envVar, "=") + if len(parts) != 2 { + return nil, fmt.Errorf("invalid env var: %s", envVar) + } + envVars[parts[0]] = parts[1] + } + + return envVars, nil +} + func (cmd *SSHCmd) jumpContainer( ctx context.Context, devPodConfig *config.Config, @@ -231,6 +254,11 @@ func (cmd *SSHCmd) jumpContainer( return err } + envVars, err := cmd.retrieveEnVars() + if err != nil { + return err + } + // tunnel to container return tunnel.NewContainerTunnel(client, cmd.Proxy, log). Run(ctx, func(ctx context.Context, containerClient *ssh.Client) error { @@ -239,7 +267,7 @@ func (cmd *SSHCmd) jumpContainer( // start ssh tunnel return cmd.startTunnel(ctx, devPodConfig, containerClient, client.Workspace(), log) - }, devPodConfig) + }, devPodConfig, envVars) } func (cmd *SSHCmd) forwardTimeout(log log.Logger) (time.Duration, error) { @@ -394,6 +422,11 @@ func (cmd *SSHCmd) startTunnel(ctx context.Context, devPodConfig *config.Config, command = fmt.Sprintf("su -c \"%s\" '%s'", command, cmd.User) } + envVars, err := cmd.retrieveEnVars() + if err != nil { + return err + } + // Traffic is coming in from the outside, we need to forward it to the container if cmd.Proxy || cmd.Stdio { if cmd.Proxy { @@ -404,7 +437,7 @@ func (cmd *SSHCmd) startTunnel(ctx context.Context, devPodConfig *config.Config, }() } - return devssh.Run(ctx, containerClient, command, os.Stdin, os.Stdout, writer) + return devssh.Run(ctx, containerClient, command, os.Stdin, os.Stdout, writer, envVars) } return machine.StartSSHSession( @@ -414,7 +447,7 @@ func (cmd *SSHCmd) startTunnel(ctx context.Context, devPodConfig *config.Config, !cmd.Proxy && cmd.AgentForwarding && devPodConfig.ContextOption(config.ContextOptionSSHAgentForwarding) == "true", func(ctx context.Context, stdin io.Reader, stdout io.Writer, stderr io.Writer) error { - return devssh.Run(ctx, containerClient, command, stdin, stdout, stderr) + return devssh.Run(ctx, containerClient, command, stdin, stdout, stderr, envVars) }, writer, ) @@ -577,7 +610,7 @@ func (cmd *SSHCmd) setupGPGAgent( command = fmt.Sprintf("su -c \"%s\" '%s'", command, cmd.User) } - return devssh.Run(ctx, containerClient, command, nil, writer, writer) + return devssh.Run(ctx, containerClient, command, nil, writer, writer, nil) } func mergeDevPodSshOptions(cmd *SSHCmd) error { @@ -612,7 +645,7 @@ func startWorkspaceCredentialServer(ctx context.Context, client *ssh.Client, use args = append(args, "--runner") command = fmt.Sprintf("%s %s", command, strings.Join(args, " ")) - if err := devssh.Run(ctx, client, command, stdin, stdout, writer); err != nil { + if err := devssh.Run(ctx, client, command, stdin, stdout, writer, nil); err != nil { return fmt.Errorf("run credentials server: %w", err) } diff --git a/pkg/devcontainer/sshtunnel/sshtunnel.go b/pkg/devcontainer/sshtunnel/sshtunnel.go index ff6aff97e..8c64f09b0 100644 --- a/pkg/devcontainer/sshtunnel/sshtunnel.go +++ b/pkg/devcontainer/sshtunnel/sshtunnel.go @@ -124,7 +124,7 @@ func ExecuteCommand( writer := log.Writer(logrus.InfoLevel, false) defer writer.Close() - err = devssh.Run(ctx, sshClient, command, gRPCConnStdinReader, gRPCConnStdoutWriter, writer) + err = devssh.Run(ctx, sshClient, command, gRPCConnStdinReader, gRPCConnStdoutWriter, writer, nil) if err != nil { errChan <- errors.Wrap(err, "run agent command") } else { diff --git a/pkg/gpg/gpg_forwarding.go b/pkg/gpg/gpg_forwarding.go index b8132eec4..89c12093f 100644 --- a/pkg/gpg/gpg_forwarding.go +++ b/pkg/gpg/gpg_forwarding.go @@ -40,7 +40,7 @@ func IsGpgTunnelRunning( // capture the output, if it's empty it means we don't have gpg-forwarding var out bytes.Buffer - err := devssh.Run(ctx, client, command, nil, &out, writer) + err := devssh.Run(ctx, client, command, nil, &out, writer, nil) return err == nil && len(out.Bytes()) > 1 } diff --git a/pkg/ssh/helper.go b/pkg/ssh/helper.go index 6dd785433..ff547b4b0 100644 --- a/pkg/ssh/helper.go +++ b/pkg/ssh/helper.go @@ -90,13 +90,20 @@ func ConfigFromKeyBytes(keyBytes []byte) (*ssh.ClientConfig, error) { return clientConfig, nil } -func Run(ctx context.Context, client *ssh.Client, command string, stdin io.Reader, stdout io.Writer, stderr io.Writer) error { +func Run(ctx context.Context, client *ssh.Client, command string, stdin io.Reader, stdout io.Writer, stderr io.Writer, envVars map[string]string) error { sess, err := client.NewSession() if err != nil { return err } defer sess.Close() + for k, v := range envVars { + err = sess.Setenv(k, v) + if err != nil { + return err + } + } + exit := make(chan struct{}) defer close(exit) go func() { diff --git a/pkg/tunnel/container.go b/pkg/tunnel/container.go index e1eddf430..0dd4cad90 100644 --- a/pkg/tunnel/container.go +++ b/pkg/tunnel/container.go @@ -38,7 +38,7 @@ type ContainerHandler struct { type Handler func(ctx context.Context, containerClient *ssh.Client) error -func (c *ContainerHandler) Run(ctx context.Context, handler Handler, cfg *config.Config) error { +func (c *ContainerHandler) Run(ctx context.Context, handler Handler, cfg *config.Config, envVars map[string]string) error { if handler == nil { return nil } @@ -118,7 +118,7 @@ func (c *ContainerHandler) Run(ctx context.Context, handler Handler, cfg *config } // wait until we are done - if err := c.runRunInContainer(cancelCtx, sshClient, handler); err != nil { + if err := c.runRunInContainer(cancelCtx, sshClient, handler, envVars); err != nil { containerChan <- fmt.Errorf("run in container: %w", err) } else { containerChan <- nil @@ -164,7 +164,7 @@ func (c *ContainerHandler) updateConfig(ctx context.Context, sshClient *ssh.Clie } c.log.Debugf("Run command in container: %s", command) - err = devssh.Run(ctx, sshClient, command, nil, buf, buf) + err = devssh.Run(ctx, sshClient, command, nil, buf, buf, nil) if err != nil { c.log.Errorf("Error updating remote workspace: %s%v", buf.String(), err) } else { @@ -174,7 +174,7 @@ func (c *ContainerHandler) updateConfig(ctx context.Context, sshClient *ssh.Clie } } -func (c *ContainerHandler) runRunInContainer(ctx context.Context, sshClient *ssh.Client, runInContainer Handler) error { +func (c *ContainerHandler) runRunInContainer(ctx context.Context, sshClient *ssh.Client, runInContainer Handler, envVars map[string]string) error { // compress info workspaceInfo, _, err := c.client.AgentInfo(provider.CLIOptions{Proxy: c.proxy}) if err != nil { @@ -211,7 +211,7 @@ func (c *ContainerHandler) runRunInContainer(ctx context.Context, sshClient *ssh if c.log.GetLevel() == logrus.DebugLevel { command += " --debug" } - err = devssh.Run(cancelCtx, sshClient, command, stdinReader, stdoutWriter, writer) + err = devssh.Run(cancelCtx, sshClient, command, stdinReader, stdoutWriter, writer, envVars) if err != nil { c.log.Errorf("Error tunneling to container: %v", err) return diff --git a/pkg/tunnel/services.go b/pkg/tunnel/services.go index ef528cb65..d66b9d8a2 100644 --- a/pkg/tunnel/services.go +++ b/pkg/tunnel/services.go @@ -131,7 +131,7 @@ func RunInContainer( command += " --debug" } - err = devssh.Run(cancelCtx, containerClient, command, stdinReader, stdoutWriter, writer) + err = devssh.Run(cancelCtx, containerClient, command, stdinReader, stdoutWriter, writer, nil) if err != nil { return err } @@ -147,7 +147,7 @@ func RunInContainer( func forwardDevContainerPorts(ctx context.Context, containerClient *ssh.Client, extraPorts []string, exitAfterTimeout time.Duration, log log.Logger) ([]string, error) { stdout := &bytes.Buffer{} stderr := &bytes.Buffer{} - err := devssh.Run(ctx, containerClient, "cat "+setup.ResultLocation, nil, stdout, stderr) + err := devssh.Run(ctx, containerClient, "cat "+setup.ResultLocation, nil, stdout, stderr, nil) if err != nil { return nil, fmt.Errorf("retrieve container result: %s\n%s%w", stdout.String(), stderr.String(), err) }