all repos

gopher.nvim @ 08f3d2f88fe8258ee336f4120ca352c3db8498b3

Minimalistic plugin for Go development

gopher.nvim/lua/gopher/_utils/ts/nodes.lua(view raw)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
local ts_query = require "nvim-treesitter.query"
local parsers = require "nvim-treesitter.parsers"
local locals = require "nvim-treesitter.locals"
local M = {}

---Report whether position (row, col) lies inside the inclusive span
---[(start_row, start_col), (end_row, end_col)].
---Columns are only consulted when `row` is one of the span's boundary rows.
---@return boolean
local function intersects(row, col, start_row, start_col, end_row, end_col)
  -- Entirely above or below the span's rows.
  if row < start_row or row > end_row then
    return false
  end

  -- On a boundary row the column must also fall inside the span.
  local before_start = row == start_row and col < start_col
  return not (before_start or (row == end_row and col > end_col))
end

---Keep only the entries whose `dim` span contains the given position.
---
---@param nodes table list of entries carrying `dim = { s = { r, c }, e = { r, c } }`
---@param row integer row compared directly against `dim.s.r` / `dim.e.r`
---@param col integer column compared directly against `dim.s.c` / `dim.e.c`
---@return table a new list holding the matching entries, in their original order
function M.intersect_nodes(nodes, row, col)
  local matches = {}
  for i = 1, #nodes do
    local candidate = nodes[i]
    local span = candidate.dim
    local first, last = span.s, span.e
    if intersects(row, col, first.r, first.c, last.r, last.c) then
      matches[#matches + 1] = candidate
    end
  end

  return matches
end

---Count the treesitter ancestors of an entry's `declaring_node`.
---Entries without a `declaring_node` count as 0.
---
---@param node table entry as produced by `M.get_all_nodes`
---@return integer
function M.count_parents(node)
  local current = node.declaring_node
  if current == nil then
    return 0
  end

  local count = 0
  current = current:parent()
  while current ~= nil do
    count = count + 1
    current = current:parent()
  end
  return count
end

---Sort entries in place by ascending nesting depth (outermost first), so the
---innermost matching node ends up last.
---
---@param nodes table list of entries as produced by `M.get_all_nodes`
---@return table the same `nodes` table, sorted
function M.sort_nodes(nodes)
  -- Walk each ancestor chain once instead of on every comparison.
  local depth = {}
  for i = 1, #nodes do
    local entry = nodes[i]
    depth[entry] = M.count_parents(entry)
  end

  table.sort(nodes, function(a, b)
    return depth[a] < depth[b]
  end)

  return nodes
end

---Split a dotted capture path such as "struct.declaration" into the part
---before the last dot ("struct") and the last segment ("declaration").
---A path without any dot yields an empty prefix and the whole path.
---@param path string
---@return string prefix
---@return string last
local function split_capture_path(path)
  local prefix, last = path:match "^(.*)%.([^.]*)$"
  if prefix == nil then
    return "", path
  end
  return prefix, last
end

---Run `query` over buffer `bufnr` and collect one entry per match that has a
---`declaration` or `clause` capture.
---
---Each entry is a table with:
---  declaring_node  the captured `declaration` / `clause` node
---  dim             { s = { r, c }, e = { r, c } }: `node:range()` with 1 added to every field
---  name            text of the match's `name` capture, or "" when there is none
---  operator        last path segment of the declaration capture ("declaration" or "clause")
---  type            capture path before that segment (e.g. "struct" for "struct.declaration")
---
---@param query string treesitter query source
---@param lang string parser language (callers in this file pass the buffer filetype)
---@param _ any unused; kept so positional callers keep working
---@param bufnr integer|nil buffer handle; 0 (current buffer) when nil
---@param pos_row integer|nil unused; kept so positional callers keep working
---@param _ any unused
---@return table|nil entries, or nil when the query fails to compile or no parser is available
function M.get_all_nodes(query, lang, _, bufnr, pos_row, _)
  bufnr = bufnr or 0

  local ok, parsed_query = pcall(function()
    return vim.treesitter.parse_query(lang, query)
  end)
  if not ok then
    return nil
  end

  -- get_parser may return nil (e.g. no parser installed for `lang`); treat that
  -- like a failed query instead of crashing on `parser:parse()`.
  local parser = parsers.get_parser(bufnr, lang)
  if parser == nil then
    return nil
  end
  local tree = parser:parse()[1]
  if tree == nil then
    return nil
  end

  local root = tree:root()
  local start_row, _, end_row, _ = root:range()
  local results = {}

  for match in ts_query.iter_prepared_matches(parsed_query, root, bufnr, start_row, end_row) do
    local declaration_node, dim
    local name, node_type, operator = "", "", ""

    -- The order in which captures are visited should not be relied upon
    -- (presumably pairs() order -- verify), so type/operator are recorded
    -- together with the declaration node rather than from the last capture seen.
    locals.recurse_local_nodes(match, function(_, node, path)
      local capture_type, capture_op = split_capture_path(path)

      if capture_op == "name" then
        name = vim.treesitter.query.get_node_text(node, bufnr)
      elseif capture_op == "declaration" or capture_op == "clause" then
        declaration_node = node
        node_type, operator = capture_type, capture_op
        local s_row, s_col, e_row, e_col = node:range()
        dim = {
          s = { r = s_row + 1, c = s_col + 1 },
          e = { r = e_row + 1, c = e_col + 1 },
        }
      end
    end)

    if declaration_node ~= nil then
      results[#results + 1] = {
        declaring_node = declaration_node,
        dim = dim,
        name = name,
        operator = operator,
        type = node_type,
      }
    end
  end

  return results
end

---Find the query matches whose declaration span contains a position.
---Prints a message and returns nil when nothing is found.
---
---@param query string treesitter query source
---@param default string forwarded to `M.get_all_nodes` as its third argument
---@param bufnr integer|nil buffer handle; defaults to the current buffer
---@param row integer|nil 1-based row; when `row` or `col` is nil both come from the current window's cursor
---@param col integer|nil 0-based byte column, as returned by `nvim_win_get_cursor`
---@return table|nil matching entries ordered by `M.sort_nodes`, or nil
function M.nodes_at_cursor(query, default, bufnr, row, col)
  bufnr = bufnr or vim.api.nvim_get_current_buf()
  local ft = vim.api.nvim_buf_get_option(bufnr, "ft")
  if row == nil or col == nil then
    row, col = unpack(vim.api.nvim_win_get_cursor(0))
  end

  local nodes = M.get_all_nodes(query, ft, default, bufnr, row, col)
  if nodes == nil then
    print "Unable to find any nodes. place your cursor on a go symbol and try again"
    return nil
  end

  -- `dim` columns are 1-based (get_all_nodes adds 1 to node:range()), while the
  -- cursor column is 0-based; shift it so both sides share the same origin.
  -- Without this the first character of a node on its start row never matches.
  nodes = M.sort_nodes(M.intersect_nodes(nodes, row, col + 1))
  if #nodes == 0 then
    print("Unable to find any nodes at pos. " .. tostring(row) .. ":" .. tostring(col))
    return nil
  end

  return nodes
end

return M