diff --git a/lua/gopher/_utils/ts/init.lua b/lua/gopher/_utils/ts/init.lua index 90fc899..36b6b4a 100644 --- a/lua/gopher/_utils/ts/init.lua +++ b/lua/gopher/_utils/ts/init.lua @@ -1,11 +1,14 @@ local ts = { queries = { struct = [[ -(type_spec name: (type_identifier) @_name - type: (struct_type)) + (type_spec name: (type_identifier) @_name + type: (struct_type)) + ]], + func = [[ + [(function_declaration name: (identifier) @_name) + (method_declaration name: (field_identifier) @_name)] ]], - -- struct_block = [[((type_declaration (type_spec name:(type_identifier) @struct.name type: (struct_type)))@struct.declaration)]], -- package = [[(package_clause (package_identifier)@package.name)@package.clause]], -- interface = [[((type_declaration (type_spec name:(type_identifier) @interface.name type:(interface_type)))@interface.declaration)]], -- method_name = [[((method_declaration receiver: (parameter_list)@method.receiver name: (field_identifier)@method.name body:(block))@method.declaration)]], @@ -13,15 +16,21 @@ local ts = { }, } ----@param parent_type string +---@param parent_type string|[string] ---@param node TSNode ---@return TSNode? local function get_parrent_node(parent_type, node) ---@type TSNode? local current = node while current do - if current:type() == parent_type then - break + if type(parent_type) == "string" then + if current:type() == parent_type then + break + end + elseif type(parent_type) == "table" then + if vim.tbl_contains(parent_type, current:type()) then + break + end end current = current:parent() @@ -32,30 +41,13 @@ local function get_parrent_node(parent_type, node) return current end ----@param bufnr string ----@return table|nil -function ts.get_struct_node_at_pos(bufnr) - vim.validate { - bufnr = { bufnr, "number" }, - } - - local node = vim.treesitter.get_node() - if not node then - error "No nodes found under cursor" - end - +---@param query vim.treesitter.Query +---@param node TSNode +---@param bufnr integer +---@return {name:string} +local function query_and_get_captures(query, node, bufnr) local res = {} - local r = get_parrent_node("type_spec", node) - if not r then - error "No struct found under cursor" - end - - local start_row, _, end_row, _ = r:range() - res["start_line"] = start_row + 1 - res["end_line"] = end_row + 1 - - local query = vim.treesitter.query.parse("go", ts.queries.struct) - for _, match, _ in query:iter_matches(r, bufnr) do + for _, match, _ in query:iter_matches(node, bufnr) do for capture_id, captured_node in pairs(match) do local capture_name = query.captures[capture_id] if capture_name == "_name" then @@ -67,4 +59,49 @@ function ts.get_struct_node_at_pos(bufnr) return res end +---@param bufnr integer +---@return table +function ts.get_struct_under_cursor(bufnr) + local node = vim.treesitter.get_node() + if not node then + error "No nodes found under cursor" + end + + local parent_node = get_parrent_node("type_spec", node) + if not parent_node then + error "No struct found under cursor" + end + + local query = vim.treesitter.query.parse("go", ts.queries.struct) + local res = query_and_get_captures(query, parent_node, bufnr) + + local start_row, _, end_row, _ = parent_node:range() + res["start_line"] = start_row + 1 + res["end_line"] = end_row + 1 + + return res +end + +---@param bufnr integer +function ts.get_func_under_cursor(bufnr) + local node = vim.treesitter.get_node() + if not node then + error "No nodes found under cursor" + end + + local parent_node = get_parrent_node({ "function_declaration", "method_declaration" }, node) + if not parent_node then + error "No struct found under cursor" + end + + local query = vim.treesitter.query.parse("go", ts.queries.func) + local res = query_and_get_captures(query, parent_node, bufnr) + + local start_row, _, end_row, _ = parent_node:range() + res["start_line"] = start_row + 1 + res["end_line"] = end_row + 1 + + return res +end + return ts diff --git a/lua/gopher/comment.lua b/lua/gopher/comment.lua index 0bb1268..9c1f399 100644 --- a/lua/gopher/comment.lua +++ b/lua/gopher/comment.lua @@ -15,7 +15,7 @@ local function generate(row, col) return comment, ns end - ns = ts_utils.get_struct_node_at_pos(row, col, nil) + ns = ts_utils.get_struct_under_cursor(row, col, nil) if ns ~= nil then comment = "// " .. ns.name .. " " .. ns.type .. " " return comment, ns diff --git a/lua/gopher/gotests.lua b/lua/gopher/gotests.lua index e97f26c..deab8b3 100644 --- a/lua/gopher/gotests.lua +++ b/lua/gopher/gotests.lua @@ -77,13 +77,10 @@ end -- generate unit test for one function function gotests.func_test() - local ns = ts_utils.get_func_method_node_at_pos(unpack(vim.api.nvim_win_get_cursor(0))) - if ns == nil or ns.name == nil then - u.notify("cursor on func/method and execute the command again", vim.log.levels.WARN) - return - end + local bufnr = vim.api.nvim_get_current_buf() + local func = ts_utils.get_func_under_cursor(bufnr) - add_test { "-only", ns.name } + add_test { "-only", func.name } end -- generate unit tests for all functions in current file diff --git a/lua/gopher/impl.lua b/lua/gopher/impl.lua index 0ed3e76..673a119 100644 --- a/lua/gopher/impl.lua +++ b/lua/gopher/impl.lua @@ -41,7 +41,7 @@ local impl = {} ---@return string ---@private local function get_struct() - local ns = ts_utils.get_struct_node_at_pos(unpack(vim.api.nvim_win_get_cursor(0))) + local ns = ts_utils.get_struct_under_cursor(unpack(vim.api.nvim_win_get_cursor(0))) if ns == nil then u.notify "put cursor on a struct or specify a receiver" return "" diff --git a/lua/gopher/struct_tags.lua b/lua/gopher/struct_tags.lua index 9141983..73c7612 100644 --- a/lua/gopher/struct_tags.lua +++ b/lua/gopher/struct_tags.lua @@ -34,7 +34,7 @@ local struct_tags = {} local function modify(...) local fpath = vim.fn.expand "%" ---@diagnostic disable-line: missing-parameter local bufnr = vim.api.nvim_get_current_buf() - local struct = ts_utils.get_struct_node_at_pos(bufnr) + local struct = ts_utils.get_struct_under_cursor(bufnr) -- set user args for cmd local cmd_args = {}