From 08f3d2f88fe8258ee336f4120ca352c3db8498b3 Mon Sep 17 00:00:00 2001 From: Smirnov Olexander Date: Mon, 30 May 2022 15:51:25 +0300 Subject: [PATCH] feat: add treesitter utils --- lua/gopher/_utils/ts/init.lua | 34 +++++++++ lua/gopher/_utils/ts/nodes.lua | 136 +++++++++++++++++++++++++++++++++ 2 files changed, 170 insertions(+) create mode 100644 lua/gopher/_utils/ts/init.lua create mode 100644 lua/gopher/_utils/ts/nodes.lua diff --git a/lua/gopher/_utils/ts/init.lua b/lua/gopher/_utils/ts/init.lua new file mode 100644 index 0000000..3fbd1ff --- /dev/null +++ b/lua/gopher/_utils/ts/init.lua @@ -0,0 +1,34 @@ +local nodes = require "gopher._utils.ts.nodes" +local M = { + querys = { + struct_block = [[((type_declaration (type_spec name:(type_identifier) @struct.name type: (struct_type)))@struct.declaration)]], + em_struct_block = [[(field_declaration name:(field_identifier)@struct.name type: (struct_type)) @struct.declaration]], + }, +} + +---@return table +local function get_name_defaults() + return { + ["func"] = "function", + ["if"] = "if", + ["else"] = "else", + ["for"] = "for", + } +end + +---@param row any +---@param col any +---@param bufnr any +---@return table|nil +function M.get_struct_node_at_pos(row, col, bufnr) + local query = M.querys.struct_block .. " " .. M.querys.em_struct_block + local bufn = bufnr or vim.api.nvim_get_current_buf() + local ns = nodes.nodes_at_cursor(query, get_name_defaults(), bufn, row, col) + if ns == nil then + print "struct not found" + else + return ns[#ns] + end +end + +return M diff --git a/lua/gopher/_utils/ts/nodes.lua b/lua/gopher/_utils/ts/nodes.lua new file mode 100644 index 0000000..d3b6f9a --- /dev/null +++ b/lua/gopher/_utils/ts/nodes.lua @@ -0,0 +1,136 @@ +local ts_query = require "nvim-treesitter.query" +local parsers = require "nvim-treesitter.parsers" +local locals = require "nvim-treesitter.locals" +local M = {} + +local function intersects(row, col, sRow, sCol, eRow, eCol) + if sRow > row or eRow < row then + return false + end + + if sRow == row and sCol > col then + return false + end + + if eRow == row and eCol < col then + return false + end + + return true +end + +---@param nodes table +---@param row string +---@param col string +---@return table +function M.intersect_nodes(nodes, row, col) + local found = {} + for idx = 1, #nodes do + local node = nodes[idx] + local sRow = node.dim.s.r + local sCol = node.dim.s.c + local eRow = node.dim.e.r + local eCol = node.dim.e.c + + if intersects(row, col, sRow, sCol, eRow, eCol) then + table.insert(found, node) + end + end + + return found +end + +---@param nodes table +---@return table +function M.sort_nodes(nodes) + table.sort(nodes, function(a, b) + return M.count_parents(a) < M.count_parents(b) + end) + + return nodes +end + +---@param query string +---@param lang string +---@param bufnr integer +---@param pos_row string +---@return string +function M.get_all_nodes(query, lang, _, bufnr, pos_row, _) + bufnr = bufnr or 0 + pos_row = pos_row or 30000 + + local ok, parsed_query = pcall(function() + return vim.treesitter.parse_query(lang, query) + end) + if not ok then + return nil + end + + local parser = parsers.get_parser(bufnr, lang) + local root = parser:parse()[1]:root() + local start_row, _, end_row, _ = root:range() + local results = {} + + for match in ts_query.iter_prepared_matches(parsed_query, root, bufnr, start_row, end_row) do + local sRow, sCol, eRow, eCol, declaration_node + local type, name, op = "", "", "" + locals.recurse_local_nodes(match, function(_, node, path) + local idx = string.find(path, ".[^.]*$") + op = string.sub(path, idx + 1, #path) + type = string.sub(path, 1, idx - 1) + + if op == "name" then + name = vim.treesitter.query.get_node_text(node, bufnr) + elseif op == "declaration" or op == "clause" then + declaration_node = node + sRow, sCol, eRow, eCol = node:range() + sRow = sRow + 1 + eRow = eRow + 1 + sCol = sCol + 1 + eCol = eCol + 1 + end + end) + + if declaration_node ~= nil then + table.insert(results, { + declaring_node = declaration_node, + dim = { s = { r = sRow, c = sCol }, e = { r = eRow, c = eCol } }, + name = name, + operator = op, + type = type, + }) + end + end + + return results +end + +---@param query string +---@param default string +---@param bufnr string +---@param row string +---@param col string +---@return table +function M.nodes_at_cursor(query, default, bufnr, row, col) + bufnr = bufnr or vim.api.nvim_get_current_buf() + local ft = vim.api.nvim_buf_get_option(bufnr, "ft") + if row == nil or col == nil then + row, col = unpack(vim.api.nvim_win_get_cursor(0)) + end + + local nodes = M.get_all_nodes(query, ft, default, bufnr, row, col) + if nodes == nil then + print "Unable to find any nodes. place your cursor on a go symbol and try again" + return nil + end + + nodes = M.sort_nodes(M.intersect_nodes(nodes, row, col)) + if nodes == nil or #nodes == 0 then + print("Unable to find any nodes at pos. " .. tostring(row) .. ":" .. tostring(col)) + return nil + end + + return nodes +end + +return M