From 1277e7b74f0afb4cca4b298be7d38ae943fa187b Mon Sep 17 00:00:00 2001 From: Oleksandr Smirnov Date: Tue, 18 Mar 2025 14:47:07 +0200 Subject: [PATCH] refactor(ts_util): start from scratch --- lua/gopher/_utils/ts/init.lua | 129 +++++++++++++---------------- lua/gopher/_utils/ts/nodes.lua | 143 --------------------------------- 2 files changed, 57 insertions(+), 215 deletions(-) delete mode 100644 lua/gopher/_utils/ts/nodes.lua diff --git a/lua/gopher/_utils/ts/init.lua b/lua/gopher/_utils/ts/init.lua index 69f28a3..0d8eca5 100644 --- a/lua/gopher/_utils/ts/init.lua +++ b/lua/gopher/_utils/ts/init.lua @@ -1,88 +1,73 @@ ----@diagnostic disable: param-type-mismatch -local nodes = require "gopher._utils.ts.nodes" -local u = require "gopher._utils" local ts = { queries = { - struct_block = [[((type_declaration (type_spec name:(type_identifier) @struct.name type: (struct_type)))@struct.declaration)]], - em_struct_block = [[(field_declaration name:(field_identifier)@struct.name type: (struct_type)) @struct.declaration]], - package = [[(package_clause (package_identifier)@package.name)@package.clause]], - interface = [[((type_declaration (type_spec name:(type_identifier) @interface.name type:(interface_type)))@interface.declaration)]], - method_name = [[((method_declaration receiver: (parameter_list)@method.receiver name: (field_identifier)@method.name body:(block))@method.declaration)]], - func = [[((function_declaration name: (identifier)@function.name) @function.declaration)]], + struct = [[ +(type_spec name: (type_identifier) @name + type: (struct_type)) + ]], + + -- struct_block = [[((type_declaration (type_spec name:(type_identifier) @struct.name type: (struct_type)))@struct.declaration)]], + -- package = [[(package_clause (package_identifier)@package.name)@package.clause]], + -- interface = [[((type_declaration (type_spec name:(type_identifier) @interface.name type:(interface_type)))@interface.declaration)]], + -- method_name = [[((method_declaration receiver: (parameter_list)@method.receiver name: (field_identifier)@method.name body:(block))@method.declaration)]], + -- func = [[((function_declaration name: (identifier)@function.name) @function.declaration)]], }, } ----@return table -local function get_name_defaults() - return { - ["func"] = "function", - ["if"] = "if", - ["else"] = "else", - ["for"] = "for", +---@param parent_type string +---@param node TSNode +---@return TSNode? +local function get_parrent_node(parent_type, node) + ---@type TSNode? + local current = node + while current do + if current:type() == parent_type then + break + end + + current = current:parent() + if current == nil then + return nil + end + end + return current +end + +---@param bufnr string +---@return table|nil +function ts.get_struct_node_at_pos(bufnr) + vim.validate { + bufnr = { bufnr, "number" }, } -end ----@param row string ----@param col string ----@param bufnr string|nil ----@return table|nil -function ts.get_struct_node_at_pos(row, col, bufnr) - local query = ts.queries.struct_block .. " " .. ts.queries.em_struct_block - local bufn = bufnr or vim.api.nvim_get_current_buf() - local ns = nodes.nodes_at_cursor(query, get_name_defaults(), bufn, row, col) - if ns == nil then - u.deferred_notify("struct not found", vim.log.levels.WARN) - return + local node = vim.treesitter.get_node() + if not node then + error "No nodes found under cursor" end - return ns[#ns] -end ----@param row string ----@param col string ----@param bufnr string|nil ----@return table|nil -function ts.get_func_method_node_at_pos(row, col, bufnr) - local query = ts.queries.func .. " " .. ts.queries.method_name - local bufn = bufnr or vim.api.nvim_get_current_buf() - local ns = nodes.nodes_at_cursor(query, get_name_defaults(), bufn, row, col) - if ns == nil then - u.deferred_notify("function not found", vim.log.levels.WARN) - return + local res = {} + local r = get_parrent_node("type_declaration", node) + if not r then + error "No struct found under cursor" end - return ns[#ns] -end ----@param row string ----@param col string ----@param bufnr string|nil ----@return table|nil -function ts.get_package_node_at_pos(row, col, bufnr) - if row > 10 then - return - end - local query = ts.queries.package - local bufn = bufnr or vim.api.nvim_get_current_buf() - local ns = nodes.nodes_at_cursor(query, get_name_defaults(), bufn, row, col) - if ns == nil then - u.deferred_notify("package not found", vim.log.levels.WARN) - return - end - return ns[#ns] -end + local start_row, _, end_row, _ = r:range() + res["start_line"] = start_row + 1 + res["end_line"] = end_row + 1 ----@param row string ----@param col string ----@param bufnr string|nil ----@return table|nil -function ts.get_interface_node_at_pos(row, col, bufnr) - local query = ts.queries.interface - local bufn = bufnr or vim.api.nvim_get_current_buf() - local ns = nodes.nodes_at_cursor(query, get_name_defaults(), bufn, row, col) - if ns == nil then - u.deferred_notify("interface not found", vim.log.levels.WARN) - return + local query = vim.treesitter.query.parse("go", ts.queries.struct) + + for _, match, _ in query:iter_matches(r, bufnr) do + for capture_id, captured_node in pairs(match) do + local capture_name = query.captures[capture_id] + local text = vim.treesitter.get_node_text(captured_node, bufnr) + + if capture_name == "name" then + res["name"] = text + end + end end - return ns[#ns] + + return res end return ts diff --git a/lua/gopher/_utils/ts/nodes.lua b/lua/gopher/_utils/ts/nodes.lua deleted file mode 100644 index e9c730d..0000000 --- a/lua/gopher/_utils/ts/nodes.lua +++ /dev/null @@ -1,143 +0,0 @@ -local ts_query = require "nvim-treesitter.query" -local parsers = require "nvim-treesitter.parsers" -local locals = require "nvim-treesitter.locals" -local u = require "gopher._utils" -local M = {} - -local function intersects(row, col, sRow, sCol, eRow, eCol) - if sRow > row or eRow < row then - return false - end - - if sRow == row and sCol > col then - return false - end - - if eRow == row and eCol < col then - return false - end - - return true -end - ----@param nodes table ----@param row string ----@param col string ----@return table -function M.intersect_nodes(nodes, row, col) - local found = {} - for idx = 1, #nodes do - local node = nodes[idx] - local sRow = node.dim.s.r - local sCol = node.dim.s.c - local eRow = node.dim.e.r - local eCol = node.dim.e.c - - if intersects(row, col, sRow, sCol, eRow, eCol) then - table.insert(found, node) - end - end - - return found -end - ----@param nodes table ----@return table -function M.sort_nodes(nodes) - table.sort(nodes, function(a, b) - return M.count_parents(a) < M.count_parents(b) - end) - - return nodes -end - ----@param query string ----@param lang string ----@param bufnr integer ----@param pos_row string ----@return string -function M.get_all_nodes(query, lang, _, bufnr, pos_row, _) - bufnr = bufnr or 0 - pos_row = pos_row or 30000 - - local ok, parsed_query = pcall(function() - return vim.treesitter.query.parse(lang, query) - end) - if not ok then - return nil - end - - local parser = parsers.get_parser(bufnr, lang) - local root = parser:parse()[1]:root() - local start_row, _, end_row, _ = root:range() - local results = {} - - for match in ts_query.iter_prepared_matches(parsed_query, root, bufnr, start_row, end_row) do - local sRow, sCol, eRow, eCol, declaration_node - local type, name, op = "", "", "" - locals.recurse_local_nodes(match, function(_, node, path) - local idx = string.find(path, ".[^.]*$") - op = string.sub(path, idx + 1, #path) - type = string.sub(path, 1, idx - 1) - - if op == "name" then - name = vim.treesitter.get_node_text(node, bufnr) - elseif op == "declaration" or op == "clause" then - declaration_node = node - sRow, sCol, eRow, eCol = node:range() - sRow = sRow + 1 - eRow = eRow + 1 - sCol = sCol + 1 - eCol = eCol + 1 - end - end) - - if declaration_node ~= nil then - table.insert(results, { - declaring_node = declaration_node, - dim = { s = { r = sRow, c = sCol }, e = { r = eRow, c = eCol } }, - name = name, - operator = op, - type = type, - }) - end - end - - return results -end - ----@param query string ----@param default string ----@param bufnr string ----@param row string ----@param col string ----@return table -function M.nodes_at_cursor(query, default, bufnr, row, col) - bufnr = bufnr or vim.api.nvim_get_current_buf() - local ft = vim.api.nvim_buf_get_option(bufnr, "ft") - if row == nil or col == nil then - row, col = unpack(vim.api.nvim_win_get_cursor(0)) - end - - local nodes = M.get_all_nodes(query, ft, default, bufnr, row, col) - if nodes == nil then - u.deferred_notify( - "Unable to find any nodes. Place your cursor on a go symbol and try again", - vim.log.levels.DEBUG - ) - return nil - end - - nodes = M.sort_nodes(M.intersect_nodes(nodes, row, col)) - if nodes == nil or #nodes == 0 then - u.deferred_notify( - "Unable to find any nodes at pos. " .. tostring(row) .. ":" .. tostring(col), - vim.log.levels.DEBUG - ) - return nil - end - - return nodes -end - -return M