Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions internal/wintun/wintun_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
package wintun

import (
"runtime"
"log"
"syscall"
"unsafe"

Expand All @@ -30,6 +30,7 @@ var (
)

func closeAdapter(wintun *Adapter) {
log.Println("[tomi] closeAdapter")
syscall.SyscallN(procWintunCloseAdapter.Addr(), 1, wintun.handle, 0, 0)
}

Expand All @@ -39,6 +40,7 @@ func closeAdapter(wintun *Adapter) {
// deterministically. If it is set to nil, the GUID is chosen by the system at random,
// and hence a new NLA entry is created for each new adapter.
func CreateAdapter(name string, tunnelType string, requestedGUID *windows.GUID) (wintun *Adapter, err error) {
log.Println("[tomi] CreateAdapter")
var name16 *uint16
name16, err = windows.UTF16PtrFromString(name)
if err != nil {
Expand All @@ -55,12 +57,13 @@ func CreateAdapter(name string, tunnelType string, requestedGUID *windows.GUID)
return
}
wintun = &Adapter{handle: r0}
runtime.SetFinalizer(wintun, closeAdapter)
//runtime.SetFinalizer(wintun, closeAdapter)
return
}

// OpenAdapter opens an existing Wintun adapter by name.
func OpenAdapter(name string) (wintun *Adapter, err error) {
log.Println("[tomi] OpenAdapter")
var name16 *uint16
name16, err = windows.UTF16PtrFromString(name)
if err != nil {
Expand All @@ -72,13 +75,15 @@ func OpenAdapter(name string) (wintun *Adapter, err error) {
return
}
wintun = &Adapter{handle: r0}
runtime.SetFinalizer(wintun, closeAdapter)
//runtime.SetFinalizer(wintun, closeAdapter)
return
}

// Close closes a Wintun adapter.
func (wintun *Adapter) Close() (err error) {
runtime.SetFinalizer(wintun, nil)
log.Println("[tomi] CloseAdapter")

//runtime.SetFinalizer(wintun, nil)
r1, _, e1 := syscall.Syscall(procWintunCloseAdapter.Addr(), 1, wintun.handle, 0, 0)
if r1 == 0 {
err = e1
Expand Down
46 changes: 42 additions & 4 deletions tun_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"crypto/md5"
"errors"
"fmt"
"log"
"math"
"net"
"net/netip"
Expand Down Expand Up @@ -38,18 +39,49 @@ type NativeTun struct {
fwpmSession uintptr
}

var logPrefix = "[tomi][tun]"

func New(options Options) (WinTun, error) {
if options.FileDescriptor != 0 {
return nil, os.ErrInvalid
}
adapter, err := wintun.CreateAdapter(options.Name, TunnelType, generateGUIDByDeviceName(options.Name))

var adapter *wintun.Adapter = nil

logPrefix = "[tomi][tun:" + options.Name + "]"
log.Println(logPrefix, "New() start")

// check tun device
netInterface, err := net.InterfaceByName(options.Name)
if err != nil {
return nil, err
log.Println(logPrefix, "find interface:", options.Name, ", failed:", err.Error())
}

if err == nil {
log.Println(logPrefix, "found interface:", netInterface.Name)
}

if err == nil && netInterface.Name == options.Name {
log.Println(logPrefix, "tun device found, just opening it")
adapter, err = wintun.OpenAdapter(options.Name)
if err != nil {
log.Println(logPrefix, "open tun adapter failed: "+err.Error())
return nil, errors.New("open tun adapter failed: " + err.Error())
}
} else {
log.Println(logPrefix, "tun device not found, create it")
adapter, err = wintun.CreateAdapter(options.Name, TunnelType, generateGUIDByDeviceName(options.Name))
if err != nil {
log.Println(logPrefix, "create tun adapter failed: "+err.Error())
return nil, errors.New("create tun adapter failed: " + err.Error())
}
}

nativeTun := &NativeTun{
adapter: adapter,
options: options,
}

session, err := adapter.StartSession(0x800000)
if err != nil {
return nil, err
Expand All @@ -58,10 +90,13 @@ func New(options Options) (WinTun, error) {
nativeTun.readWait = session.ReadWaitEvent()
err = nativeTun.configure()
if err != nil {
log.Println(logPrefix, "tun configure failed: "+err.Error())
session.End()
adapter.Close()
//adapter.Close()
return nil, err
}

log.Println(logPrefix, "New() done")
return nativeTun, nil
}

Expand Down Expand Up @@ -522,20 +557,23 @@ func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
}

func (t *NativeTun) Close() error {
log.Println(logPrefix, "Close() start")
var err error
t.closeOnce.Do(func() {
t.close.Store(1)
windows.SetEvent(t.readWait)
t.running.Wait()
t.session.End()
t.adapter.Close()
//t.adapter.Close()
if t.fwpmSession != 0 {
winsys.FwpmEngineClose0(t.fwpmSession)
}
if t.options.AutoRoute {
windnsapi.FlushResolverCache()
}
})

log.Println(logPrefix, "Close() done")
return err
}

Expand Down