// Copyright 2011 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package ssh import ( "errors" "io" "sync" ) // A Channel is an ordered, reliable, duplex stream that is multiplexed over an // SSH connection. type Channel interface { // Accept accepts the channel creation request. Accept() error // Reject rejects the channel creation request. After calling this, no // other methods on the Channel may be called. If they are then the // peer is likely to signal a protocol error and drop the connection. Reject(reason RejectionReason, message string) error // Read may return a ChannelRequest as an error. Read(data []byte) (int, error) Write(data []byte) (int, error) Close() error // AckRequest either sends an ack or nack to the channel request. AckRequest(ok bool) error // ChannelType returns the type of the channel, as supplied by the // client. ChannelType() string // ExtraData returns the arbitary payload for this channel, as supplied // by the client. This data is specific to the channel type. ExtraData() []byte } // ChannelRequest represents a request sent on a channel, outside of the normal // stream of bytes. It may result from calling Read on a Channel. type ChannelRequest struct { Request string WantReply bool Payload []byte } func (c ChannelRequest) Error() string { return "channel request received" } // RejectionReason is an enumeration used when rejecting channel creation // requests. See RFC 4254, section 5.1. type RejectionReason int const ( Prohibited RejectionReason = iota + 1 ConnectionFailed UnknownChannelType ResourceShortage ) type channel struct { // immutable once created chanType string extraData []byte theyClosed bool theySentEOF bool weClosed bool dead bool serverConn *ServerConn myId, theirId uint32 myWindow, theirWindow uint32 maxPacketSize uint32 err error pendingRequests []ChannelRequest pendingData []byte head, length int // This lock is inferior to serverConn.lock lock sync.Mutex cond *sync.Cond } func (c *channel) Accept() error { c.serverConn.lock.Lock() defer c.serverConn.lock.Unlock() if c.serverConn.err != nil { return c.serverConn.err } confirm := channelOpenConfirmMsg{ PeersId: c.theirId, MyId: c.myId, MyWindow: c.myWindow, MaxPacketSize: c.maxPacketSize, } return c.serverConn.writePacket(marshal(msgChannelOpenConfirm, confirm)) } func (c *channel) Reject(reason RejectionReason, message string) error { c.serverConn.lock.Lock() defer c.serverConn.lock.Unlock() if c.serverConn.err != nil { return c.serverConn.err } reject := channelOpenFailureMsg{ PeersId: c.theirId, Reason: uint32(reason), Message: message, Language: "en", } return c.serverConn.writePacket(marshal(msgChannelOpenFailure, reject)) } func (c *channel) handlePacket(packet interface{}) { c.lock.Lock() defer c.lock.Unlock() switch packet := packet.(type) { case *channelRequestMsg: req := ChannelRequest{ Request: packet.Request, WantReply: packet.WantReply, Payload: packet.RequestSpecificData, } c.pendingRequests = append(c.pendingRequests, req) c.cond.Signal() case *channelCloseMsg: c.theyClosed = true c.cond.Signal() case *channelEOFMsg: c.theySentEOF = true c.cond.Signal() default: panic("unknown packet type") } } func (c *channel) handleData(data []byte) { c.lock.Lock() defer c.lock.Unlock() // The other side should never send us more than our window. if len(data)+c.length > len(c.pendingData) { // TODO(agl): we should tear down the channel with a protocol // error. return } c.myWindow -= uint32(len(data)) for i := 0; i < 2; i++ { tail := c.head + c.length if tail > len(c.pendingData) { tail -= len(c.pendingData) } n := copy(c.pendingData[tail:], data) data = data[n:] c.length += n } c.cond.Signal() } func (c *channel) Read(data []byte) (n int, err error) { c.lock.Lock() defer c.lock.Unlock() if c.err != nil { return 0, c.err } if c.myWindow <= uint32(len(c.pendingData))/2 { packet := marshal(msgChannelWindowAdjust, windowAdjustMsg{ PeersId: c.theirId, AdditionalBytes: uint32(len(c.pendingData)) - c.myWindow, }) if err := c.serverConn.writePacket(packet); err != nil { return 0, err } } for { if c.theySentEOF || c.theyClosed || c.dead { return 0, io.EOF } if len(c.pendingRequests) > 0 { req := c.pendingRequests[0] if len(c.pendingRequests) == 1 { c.pendingRequests = nil } else { oldPendingRequests := c.pendingRequests c.pendingRequests = make([]ChannelRequest, len(oldPendingRequests)-1) copy(c.pendingRequests, oldPendingRequests[1:]) } return 0, req } if c.length > 0 { tail := c.head + c.length if tail > len(c.pendingData) { tail -= len(c.pendingData) } n = copy(data, c.pendingData[c.head:tail]) c.head += n c.length -= n if c.head == len(c.pendingData) { c.head = 0 } return } c.cond.Wait() } panic("unreachable") } func (c *channel) Write(data []byte) (n int, err error) { for len(data) > 0 { c.lock.Lock() if c.dead || c.weClosed { return 0, io.EOF } if c.theirWindow == 0 { c.cond.Wait() continue } c.lock.Unlock() todo := data if uint32(len(todo)) > c.theirWindow { todo = todo[:c.theirWindow] } packet := make([]byte, 1+4+4+len(todo)) packet[0] = msgChannelData packet[1] = byte(c.theirId >> 24) packet[2] = byte(c.theirId >> 16) packet[3] = byte(c.theirId >> 8) packet[4] = byte(c.theirId) packet[5] = byte(len(todo) >> 24) packet[6] = byte(len(todo) >> 16) packet[7] = byte(len(todo) >> 8) packet[8] = byte(len(todo)) copy(packet[9:], todo) c.serverConn.lock.Lock() if err = c.serverConn.writePacket(packet); err != nil { c.serverConn.lock.Unlock() return } c.serverConn.lock.Unlock() n += len(todo) data = data[len(todo):] } return } func (c *channel) Close() error { c.serverConn.lock.Lock() defer c.serverConn.lock.Unlock() if c.serverConn.err != nil { return c.serverConn.err } if c.weClosed { return errors.New("ssh: channel already closed") } c.weClosed = true closeMsg := channelCloseMsg{ PeersId: c.theirId, } return c.serverConn.writePacket(marshal(msgChannelClose, closeMsg)) } func (c *channel) AckRequest(ok bool) error { c.serverConn.lock.Lock() defer c.serverConn.lock.Unlock() if c.serverConn.err != nil { return c.serverConn.err } if ok { ack := channelRequestSuccessMsg{ PeersId: c.theirId, } return c.serverConn.writePacket(marshal(msgChannelSuccess, ack)) } else { ack := channelRequestFailureMsg{ PeersId: c.theirId, } return c.serverConn.writePacket(marshal(msgChannelFailure, ack)) } panic("unreachable") } func (c *channel) ChannelType() string { return c.chanType } func (c *channel) ExtraData() []byte { return c.extraData }