审计案例–存货

案例背景

我们的案例基于Bibitor有限责任公司,这是一家位于虚构的林肯州的连锁酒类商店。它是一家拥有约80个门店、总销售额超过4.5亿美元的大型零售商。

Bibitor要求团队对其葡萄酒和烈酒业务进行尽职调查,查看12个月期间的期初库存、期末库存、采购和销售数据。

这个案例设计为分三个阶段完成:

  1. 数据准备

我们需要把数据导入并转化为可分析的状态。

  1. 数据探索和可视化

使用在第1阶段创建的表格,将数据加载到软件中进行分析,然后创建和解释可视化图表。

  1. 统计分析

从两个案例研究中分别分析一组自变量/因变量,使用线性回归来提供对数据的一些见解。

第一阶段:数据准备

导入全部数据文件,一共有6个表格,按照以下的要求命名。

‘SalesFINAL12312016’ 命名为 ‘sales_dec’

‘PurchasesFINAL12312016’ 命名为 ‘purchase_dec’

‘InvoicePurchases12312016’ 命名为 ‘vendor_invoices_dec’

‘EndInvFINAL12312016’ 命名为 ‘end_inv_dec’

‘BegInvFINAL12312016’ 命名为 ‘beg_inv_dec’

‘2017PurchasePricesDec’ 命名为 ‘pricing_purchase_dec’

这里数据比较大,除了可以用硬盘直接拷贝以外,还可以远程连接SQL数据库进行下载。

# 链接云端数据库
library(RMySQL)

# 配置连接参数
db_user <- "adastudent1"
db_password <- "SUFEsoa_2024"
db_name <- "inventory"
db_host <- "rm-uf6v72e4quu5fw2dm1o.mysql.rds.aliyuncs.com"
db_port <- 3306  # 通常 MySQL 的默认端口

# 建立连接
con <- dbConnect(MySQL(), user = db_user, password = db_password, dbname = db_name, host = db_host, port = db_port)

连接之后,可以用select语句下载需要的数据。注意,在数据库中表格名称可能和上面不一样,需要和客户核对。

query <- "SELECT * FROM salesdec"
sales_dec <- dbGetQuery(con, query) 

下载所需的数据后,为了有效地利用这些数据,必须对数据进行适当的格式化并测试其完整性。

因此,在读取数据的同时,我们需要仔细阅读R语言给出的信息,判断数据是否存在格式问题。

这里可以发现,数据都是chr类型,因此我们需要做的第一件事就是设置每一列的数据类型。

library(tidyverse)
sales_dec <- sales_dec |> 
  janitor::clean_names()  |> 
  type_convert()

janiotr::clean_names()可以自动将变量名改成蛇形snake命名法。

type_convert() 函数来自于readr包(也是tidyverse核心包之一),适用这个函数的好处是,它可以自动猜测并转换数据类型。因为数据比较大,所以运行起来需要花费一段时间。结合函数运行给出的信息,我们可以再次查看一下数据格式是否是我们需要的。

sales_dec

同样的,我们把其他的表格也下载下来,并做同样的操作。

query <- "SELECT * FROM purchasesdec"
purchase_dec <- dbGetQuery(con, query) |> 
  janitor::clean_names() |> 
  type_convert()
purchase_dec
query <- "SELECT * FROM vendorinvoicesdec"
vendor_invoices_dec <- dbGetQuery(con, query) |> 
  janitor::clean_names() |> 
  type_convert()
vendor_invoices_dec
query <- "SELECT * FROM endinvdec"
end_inv_dec <- dbGetQuery(con, query) |> 
  janitor::clean_names() |> 
  type_convert()
end_inv_dec
query <- "SELECT * FROM beginvdec"
beg_inv_dec <- dbGetQuery(con, query) |> 
  janitor::clean_names() |> 
  type_convert()
beg_inv_dec
query <- "SELECT * FROM pricingpurchasesdec"
pricing_purchase_dec <- dbGetQuery(con, query) |> 
  janitor::clean_names() |> 
  type_convert()
pricing_purchase_dec

如果从硬盘中读取,可以使用下面的代码。这里read_csv函数会自动猜测变量类型,达到type_convert()的作用。

# 从文件中读取数据并清理变量名
library(tidyverse)
sales_dec <- read_csv('case/SalesFINAL12312016.csv') |> 
  janitor::clean_names()

purchase_dec <- read_csv('case/PurchasesFINAL12312016.csv') |> 
  janitor::clean_names()

vendor_invoices_dec <-  read_csv('case/InvoicePurchases12312016.csv')|> 
  janitor::clean_names()

end_inv_dec <-  read_csv('case/EndInvFINAL12312016.csv') |> 
  janitor::clean_names()

beg_inv_dec <-  read_csv('case/BegInvFINAL12312016.csv') |> 
  janitor::clean_names()

pricing_purchase_dec <-  read_csv('case/2017PurchasePricesDec.csv') |> 
  janitor::clean_names()

当然,我们也可以手动调整变量类型,例如:

pricing_purchase_dec |> 
  mutate(classification = factor(classification),
         vendor_number = as.character(vendor_number),
         size = parse_number(size))

parse_number() 函数可以把带字符的数字,例如750ML 转化为750(需要谨慎)。

至此,所有数据已经读取到我们的环境里了,可以关闭和数据库的连接了。

# 关闭连接
dbDisconnect(con)

Case 1

管理层正希望识别和监控供应商活动,从而战略性地将精力集中在关键供应商关系上。

需求:
(1) 创建一个汇总表,包含所有关键供应商(critical vendor)的账单和相关的采购活动


(2) 创建独立的表格来存储关键信息,例如按采购数量排名的”前10大供应商”。所有创建的表格都要命名为”c1_Prep_[表名]”

注意:对”关键供应商”的定义可以基于你对数据的判断和理解。例如,可以将采购金额超过1,000美元的供应商视为关键供应商。

数据清理和检查

首先我们需要对所有的数据做一个基本的检查,比如说是否存在重复值和缺失值。

sum(duplicated(beg_inv_dec)) 
sum(duplicated(end_inv_dec)) 
sum(duplicated(pricing_purchase_dec)) 
sum(duplicated(purchase_dec)) 
sum(duplicated(sales_dec)) 
sum(duplicated(vendor_invoices_dec)) 
#这需要很长一段时间!

这个数据是基于bibbtor公司的各个商店的存货进货的,因此我们可以查看每张表中的商店数量是否一致。

可以发现有一个商店是beg_inv中没有出现的,因此可能是新开的。

any(duplicated(sales_dec))
n_distinct(beg_inv_dec$store) #79
n_distinct(end_inv_dec$store) #80
n_distinct(purchase_dec$store) #80
n_distinct(sales_dec$store) #80

检查重复值

如果想知道某一列是否是唯一的,可以用如下代码来检查:

vendor_invoices_dec |> 
  summarise(n = n(), .by=c(vendor_number,vendor_name)) |> 
  arrange(n)

可以看出vendor_invoices_dec里,vendor是重复出现的,也就是一个vendor会收到一张或者多张发票(当然!)。也可以用以下代码直接检查。

any(duplicated(vendor_invoices_dec$vendor_number))
any(duplicated(vendor_invoices_dec$vendor_name))

因此,我们需要对vendor_invoices_dec做汇总,方便后续分析。

vendors <- vendor_invoices_dec |> 
  group_by(vendor_number,vendor_name) |> 
  summarise(
    n_invoice = n(),
    sum_dollar = sum(dollars),
    sum_quanitity= sum(quantity),
    avg_price = sum(dollars)/sum(quantity),
    avg_price2 = mean(dollars/quantity),
    avg_invoice_day = round(mean(pay_date - invoice_date),2)
  ) |> 
  mutate(across(where(is.numeric), ~ round(.x, 2)))
Caution

这里avg_price和avg_price2两种计算方法有什么区别,哪一种你认为更准确?

通常,第一中方法更为准确,因为他考虑到了交易数量的影响,而第二种方法,在交易数量差异很大的时候很容易失真。

假设你有以下数据:

dollars quantity
100 1
200 10

加权平均为:

avg_price = sum(dollars) / sum(quantity) = 300 / 11 = 27.28

简单平均为:

avg_price2 = (20 + 100 ) / 2 = 60

接下来,我们可以看看其他表中还有哪些有用的信息。

首先是purchase_dec,这个企业的采购清单,主要是每一个存货采购进来存放在哪个仓库store里,是什么品牌brand的,描述description是什么,规格size,从哪一家供应商采购的vendor_number, vendor_name,运输单号po_number,发货日期po_date,收到货物的日期receiving_date,付款日期,pay_date,采购价格purchase_price,数量quantity,总价dollars,分类classification。

然后是pricing_purchase_dec,这是价格清单,主要是每个品牌brand售价price是多少,购入价格purchase_price是多少。

这里面我们可以汇总出的变量有:

  1. 每个vendor运输货物的平均时长
  2. 收到货物之后平均付款时长
  3. 每个vendor的平均价格折让(需要结合两张表的信息)

首先是货物运输时长和付款时长

vendors<- purchase_dec |> 
  group_by(vendor_number, vendor_name) |> 
  summarise(
    n_purchase = n(),
    po_day = round(mean(receiving_date - po_date),2),
    pay_day = round(mean(pay_date - receiving_date),2)
  ) |> 
  right_join(vendors, by =c('vendor_name','vendor_number'))

每个vendor的平均价格折让,我们要先把pricing_purchase_dec的price合并到purchase_dec里面,那么需要一个唯一的合并字段:brand

any(duplicated(pricing_purchase_dec$brand))
any(duplicated(purchase_dec$brand))

pricing表中的brand是唯一的,因此要把purchase_dec作为连接的主表,把pricing左连接过来:

purchase_dec |> 
left_join(pricing_purchase_dec, by = 'brand') 

这样join的效率很低,因为pricing表中所有字段都会被合并过来,可以对pricing做select筛选:

purchase_dec |> 
  left_join(pricing_purchase_dec |> select(brand,price))

可以看到,只有price被合并过来了,这正是我们需要的。合并之后的表,需要再次汇总到vendor层次,来计算每个vendor收到的平均折扣。

vendors<- purchase_dec |> 
  left_join(pricing_purchase_dec |> select(brand,price)) |> 
  group_by(vendor_number,vendor_name) |> 
  summarise(
    avg_discount = sum(price*quantity-dollars)/sum(dollars)
  ) |> 
  right_join(vendors, by = c('vendor_number','vendor_name'))

至此,vendors这张表已经汇总了我们需要的重要信息,如果要创造关键供应商(critical vendor)的账单和相关的采购活动表格,只需对这张表filter即可。

注意:整个过程中我们没有创建额外的中间过程表,并且能够很自然的得到我们想要的结果,这是tidyverse语法的一个重要的优势(相比于其他数据处理语言,如sql、python)

Case 2

你注意到有些采购的库存长期滞留在货架上而不是被售出。

因此,通过识别库存采购时间与其对应销售时间之间的趋势,存在改进采购流程的机会。

在创建用于未来可视化和统计分析的汇总表时,需要考虑采购价格、季节性和供应商信息等关键数据点。

创建的表格命名为’c2_Prep_[表名]’。

purchase_dec |> 
  filter(month(receiving_date)>=12) |> 
  left_join(
    sales_dec |> 
      filter(month(sales_date)>=12),
    by = c('inventory_id','store','brand')
  )

直接合并会产生一个非常大的数据集!

purchase_dec |> 
  left_join(
    sales_dec,
    by = c('inventory_id','store','brand')
  ) 

我们可以先从某一个例开始探索。

head(sales_dec)
sales_xpl <- sales_dec |> 
  filter(
    inventory_id == '1_HARDERSFIELD_1004') |> 
  select(inventory_id,
         store,
         brand,
         sales_quantity,
         sales_dollars,
         sales_price,
         sales_date)

purchase_xpl <- purchase_dec |> 
    filter(
    inventory_id == '1_HARDERSFIELD_1004') |> 
  select(inventory_id,
         store,
         brand,
         purchase_quantity = quantity,
         purchase_dollars = dollars,
         purchase_price,
         receiving_date)

合并

purchase_xpl |> 
  left_join(sales_xpl, by =c('inventory_id','store','brand')) |> 
  filter(receiving_date <= sales_date) |> 
  view()

这里有两个问题,第一个是receiving_date和sales_date的关系,第二个是买入数量和卖出数量的关系。

on_shelft <- purchase_dec |> 
  select(inventory_id,
         store,
         brand,
         purchase_quantity = quantity,
         purchase_dollars = dollars,
         purchase_price,
         receiving_date) |> 
  left_join(  
    sales_dec |> 
  select(inventory_id,
         store,
         brand,
         sales_quantity,
         sales_dollars,
         sales_price,
         sales_date),
  by = c('inventory_id','store','brand')
  ) |> 
  filter(
    receiving_date <= sales_date
  )

可以看到,合并后数据任然大的惊人,一亿多的观察值(如果不filter的话,大概有2亿多)。

接下来,可以把这个数据汇总,因为我们不需要这么细的颗粒度。

基于这个巨大的数据库,我们在探索需要如何汇总的时候,可以不用每次都在原始数据上操作,而是用head()函数,在部分数据上进行实验。

head(on_shelft,10000) |> 
  group_by(inventory_id, store, brand,receiving_date) |> 
  mutate(cum_s_quant = cumsum(sales_quantity)) |> 
  filter(purchase_quantity >= cum_s_quant) |> 
  summarise(
    p_quantity = mean(purchase_quantity),
    s_quantity = sum(sales_quantity),
    days = sum((sales_date - receiving_date)*sales_quantity)/sum(sales_quantity),
    )|>
  view()
on_shelft_days<- on_shelft |> 
  group_by(inventory_id, store, brand,receiving_date) |> 
  mutate(cum_s_quant = cumsum(sales_quantity)) |> 
  filter(purchase_quantity >= cum_s_quant) 
head(on_shelft_days) |> 
  view()
on_shelft_days <- on_shelft_days|> 
  group_by(inventory_id, store, brand,receiving_date) |>
  summarise(
    p_quantity = mean(purchase_quantity),
    s_quantity = sum(sales_quantity),
    days = sum((sales_date - receiving_date)*sales_quantity)/sum(sales_quantity),
    )

这段代码用时太久,而且也可能因为电脑的内存不够而无法运行。

对于这个问题,一个很有效的解决办法就是数据分区,也就是把数据依据某些变量分割成数个更小型的数据。这里我们可以把一年的数据分为12个月,从12月开始处理。

sales_12 <- sales_dec |> 
  filter(as.numeric(month(sales_date))==12)
purchase_12 <- purchase_dec |> 
  filter(as.numeric(month(receiving_date))==12)
on_shelf_12 <- purchase_12 |> 
  select(inventory_id,
         store,
         brand,
         purchase_quantity = quantity,
         purchase_dollars = dollars,
         purchase_price,
         receiving_date,
         vendor_number,
         vendor_name) |> 
  left_join(  
    sales_12 |> 
  select(inventory_id,
         store,
         brand,
         sales_quantity,
         sales_dollars,
         sales_price,
         sales_date
         ),
  by = c('inventory_id','store','brand')
  ) |> 
  filter(
    receiving_date <= sales_date
  )

这样处理起来对算力的要求就小多了。

on_shelf_12 <- on_shelf_12 |> 
  mutate(
    on_shelf_days = sales_date - receiving_date
  ) |> 
  group_by(inventory_id, store, brand,receiving_date) |> 
  summarise(
    avg_on_shelf_days = mean(on_shelf_days),
    max_on_shelf_days = max(on_shelf_days)
  )

第二阶段:数据探索

使用第1阶段的数据库,创建一个包含以下内容的表格:

对于end_inv_dec表中的每条唯一记录(按’品牌Brand’(例如产品SKU)和’门店Store’划分),假设使用移动平均成本法重新计算采购价格。

假设sales_dec表中的每笔交易都会减少库存数量,而每笔采购交易都会增加库存数量。

你的结果表应包括:
[A] 一个新列,显示基于实际交易使用移动平均成本法重新计算的采购价格;
[B] 一个新列,计算end_inv_dec表中的’采购价格purchase_price’与重新计算的列(列[A])之间的差异。

假设beg_inv_dec表中的数量quantities和价格prices是截至2016年12月1日的数据。

提示:销售sales_dec和采购交易表purchases_dec包含12月之前的数据。

1.为了正确核算期初到期末,我们需要合并sales和purchase的数据

transactions <- bind_rows(
  sales_dec |>
  mutate(
    inventory_id = inventory_id,
    store = store,
    brand = brand,
    quantity = sales_quantity *(-1),
    price = sales_price,
    date = sales_date,
    .keep = 'used'
  ),
  purchase_dec |>
    mutate(
      inventory_id = inventory_id,
      store = store,
      brand = brand,
      quantity = quantity,
      price = purchase_price,
      date = receiving_date,
      .keep = 'used'
      )
) |>
  arrange(inventory_id,store, brand,date)  # 按 Store, Brand, Date 排序

把所有交易数据和期初数据合并

beg_plus_trans <- beg_inv_dec |>
  transmute(
    inventory_id = inventory_id,
    store = store,
    brand = brand,
    quantity = on_hand,
    price = price,
    date = start_date
  )|>
  arrange(inventory_id,store, brand,date) |> 
bind_rows(transactions)

这里面关键问题在于,sales里的数据是销售价格而不是成本,因此我们需要把quantity<0的时候的price替换为成本数据

beg_plus_trans <- beg_plus_trans |> 
  arrange(inventory_id,store,brand,date) |> 
  mutate(
    price = if_else(quantity >0, price, NA_real_) # 将 Quantity < 0 的 Price 替换为NA
  ) |> 
  fill(price, .direction = "down") |> # 使用向下填充的方法,将 NA 替换为上一行非 NA 值
  select(inventory_id,store,brand,date,quantity,price)

再去计算移动加权平均

moving_avg_cost <- beg_plus_trans |>
  group_by(inventory_id,store, brand) |>
  mutate(
    cumulative_quantity = cumsum(quantity),  # 累积库存数量
    cumulative_cost = cumsum(quantity * price),  # 累积成本
    avg_cost = if_else(cumulative_quantity > 0, cumulative_cost / cumulative_quantity,0)  # 平均成本
  ) 

检查一下这个数据,那么接下来只需要提取最后一个avg_cost

和期末数据对比

result <- end_inv_dec |>
  rename(date = end_date)|> 
  left_join(moving_avg_cost |> 
              select(inventory_id,store, brand,date, avg_cost),
            by = c("inventory_id", "store", "brand","date")) |>
  mutate(
    cost_diff = round((price - avg_cost),2)  # 计算差异
  )

会发现出现大量的NA,大约有15万行,这是一个非常大的比例,我们可以找一个例子来追溯。

result |> 
  filter(is.na(cost_diff))
moving_avg_cost |> filter(
  inventory_id == '1_HARDERSFIELD_62'
) |> 
  view()

这里的错误在于,12月30是最后一笔卖出记录,但是end_inv里面最后的日期是12月31日,因此应该把日期最晚的记录提取出来,再匹配就行。

final_avg_cost <- moving_avg_cost |> 
  arrange(inventory_id,brand,store,date) |> 
  group_by(inventory_id, brand,store) |> 
  slice_tail(n=1) #切出每组最后一个

arrange():

  • 需要明确按 inventory_idbrandstore 和时间列(date)排序,以确保我们提取的是按时间顺序的最后一个 moving_avg_cost 值。

  • 排序会确保分组后的 slice_tail() 函数能够提取到时间序列中的最后一个值。

result <- end_inv_dec |>
  left_join(final_avg_cost |> 
              select(inventory_id,store, brand,date, avg_cost),
            by = c("inventory_id", "store", "brand")) |> #不再需要date
  mutate(
    cost_diff = round((price - avg_cost),2)  # 计算差异
  )

检查一下NA,仍然有大约1000条,相比较22万行的总数来说,已经是一个可以接受的数字了,而且可以发现,这些都是最后on_hand = 0的记录。

result |> 
  filter (is.na(avg_cost))

追溯一下。

moving_avg_cost |> 
  filter(inventory_id == '1_HARDERSFIELD_1149   ')
beg_plus_trans |> 
  filter(inventory_id=='1_HARDERSFIELD_1149 ')
beg_inv_dec |> 
  filter(inventory_id=='1_HARDERSFIELD_1149 ')

诡异的事情出现了,这个存货似乎是凭空冒出的,那么其他NA也是一样如此吗?

result |> 
  filter (is.na(avg_cost)) |> 
  left_join(beg_plus_trans,
            by = c('inventory_id', "store",'brand')) |> 
  filter(!is.na(price.y))

确实从开始到结束,都没有这个记录。

这些”幽灵“记录不会影响存货的会计计量,但是可能反映了潜在的其他问题。

对于核对的结果,我们可以按照差异大小排列并分组。

result <- result |> 
  mutate(
    diff_ratio = abs(cost_diff)/price
  ) |> 
  mutate(
    diff_group = case_when(
      diff_ratio >= 10 ~ '10+',
      diff_ratio >= 5 & diff_ratio < 10 ~ '5 to 10',
      diff_ratio >= 1 & diff_ratio < 5 ~ '1 to 5',
      diff_ratio >= 0.5 & diff_ratio < 1 ~ '0.5 to 1',
      diff_ratio >= 0.1 & diff_ratio < 0.5 ~ '0.1 to 0.5',
      TRUE ~ '0 to 0.1'
      )
  )
result |> 
  group_by(diff_group) |> 
  summarise(
    n = n(),
    mean_ratio = mean(diff_ratio, na.rm = TRUE),
    mean_diff = mean(cost_diff,na.rm = TRUE),
    mean_cost = mean(price)
  )

用一次加权平均法,计算6月30日的期末存货成本。

第三阶段 统计分析

OLS回归

OLS回归通过一系列的自变量来预测连续性的因变量。

如果我们需要预测运费freight,那么需要找到一系列可能影响运费的因素。

在R中,拟合线性模型最基本的函数就是lm(),格式为

myfit <- lm(formula, data)

其中,formula是要拟合的模型公式,data是一个数据框,拟合公式形式如下:

Y ~ X1 + X2 + ... + Xk

简单线性回归,summary() 是来展示拟合模型的详细结果。

fit <- lm(freight ~ quantity, data = vendor_invoices_dec)
summary(fit)
fit <- lm(freight ~ quantity * approval, data = vendor_invoices_dec)
summary(fit)
fit <- lm(freight ~ log(quantity) * approval, data = vendor_invoices_dec)
summary(fit)

通过plot函数来画出拟合图像。

plot(vendor_invoices_dec$freight,vendor_invoices_dec$quantity,
     xlab = 'fright',
     ylab = 'quantity')
abline(fit)

用ggplot画出来会更加美观。

ggplot(vendor_invoices_dec, aes(x = log(quantity), y = freight)) +
  geom_point(color = "blue") +  # 绘制散点图
  geom_smooth(method = "lm", se = FALSE, color = "red") +  # 添加回归线
  labs(x = "Quantity", y = "Freight", title = "Scatterplot with Regression Line") +
  theme_minimal()  # 使用简洁主题

如果需要将拟合模型的详细结果保存为数据框,可以使用 broom 包。例如:

library(broom)

# 将模型系数转为数据框
tidy_fit <- tidy(fit)  # 整洁数据框形式的系数
print(tidy_fit)

# 提取模型摘要信息(如 R² 和 F 统计量)
glance_fit <- glance(fit)
print(glance_fit)

lm()函数是baseR的自带函数,tidyverse里的tidymodels提供了更为全面的模型选择。

library(tidymodels)  # 用于加载 parsnip 包,以及 tidymodels 生态系统中的其他工具

假设我们想要在模型中加入更多的变量,比如说approval,这是一个字符型变量,因此最好先把他转化为factor变量。

summary(as.factor(vendor_invoices_dec$approval))
ggplot(vendor_invoices_dec,
       aes(x = log(quantity),
           y = freight,
           group = as.factor(approval),
           col = as.factor(approval)))+
         geom_point()+
         geom_smooth(method = lm, se = FALSE)+
         scale_color_viridis_d(option = "plasma", end = .7)

       #`geom_smooth()` 默认的formula是 'y ~ x'

可以看出Frank审核的基本都是数量高,运费贵的订单,其他的可能是系统自动审核的。

对于这种类型的模型,普通最小二乘法(Ordinary Least Squares, OLS)是一个不错的初始方法。使用 tidymodels 时,我们首先通过 parsnip 包指定我们想要的模型的函数形式。由于结果是数值型的,并且模型应该是具有斜率和截距的线性模型,因此模型类型为“线性回归”。我们可以这样声明:

linear_reg()

这看起来有些平淡,因为就目前而言,它本身并没有真正做什么。然而,现在我们已经指定了模型的类型,就可以考虑拟合或训练模型的方法,也就是模型引擎。引擎的值通常是用于拟合或训练模型的软件和估计方法的组合。正如上面所看到的,linear_reg() 的默认引擎是普通最小二乘法(“lm”)。

我们也可以选择一个非默认选项,例如

linear_reg() |> 
  set_engine("keras")

linear_reg()文档页面列出了所有可能的引擎选项。我们将使用默认引擎(“lm”)保存我们的模型对象为 lm_mod

lm_mod <- linear_reg()

从这里,我们可以使用 fit() 函数来估计或训练模型:(这里我们并没有把approval转化为factor,因为approval只有2个取值,会自动转化为factor处理)。

fit(freight ~ quantity * approval, data = vendor_invoices_dec)

lm_fit <- 
  lm_mod |>  
  fit(freight ~ quantity * approval, data = vendor_invoices_dec)
lm_fit

lm_fit 解释:

此代码的含义是:

  1. lm_mod:定义了一个线性回归模型,使用了默认的 “lm” 引擎。

  2. fit():将线性回归模型 lm_mod 拟合到数据上。

训练完成后,lm_fit 保存了拟合的线性回归模型。

也许我们的分析需要对模型参数估计值及其统计属性进行描述。虽然 lm 对象的 summary() 函数可以提供这些信息,但它返回的结果格式难以处理且不够直观。许多模型都有一个 tidy() 方法,它可以以更可预测和更有用的格式(例如带有标准列名的数据框)提供摘要结果:

tidy(lm_fit) |> 
  view()
tidy(lm_fit) %>% 
  dwplot(dot_args = list(size = 2, color = "black"),
         whisker_args = list(color = "black"),
         vline = geom_vline(xintercept = 0, colour = "grey50", linetype = 2))

使用模型来预测

这个拟合对象 lm_fit 包含了内置的 lm 模型输出,你可以通过 lm_fit$fit 访问它。

为了获得预测结果,我们可以使用 predict() 函数来计算 quantity = 5000 情况下运费的平均值。

首先,我们需要创建新的观测值:

new_points <- expand.grid(quantity = 5000, 
                          approval = c("None", "Frank Delahunt"))
new_points

如果我们直接使用 lm() 拟合模型,花几分钟阅读 predict.lm() 的文档可以帮助我们了解如何实现这一点。然而,如果我们决定使用一种不同的模型来估计运费的大小,很可能需要完全不同的语法。

与其这样,使用 tidymodels 的好处是,预测值的类型被标准化了,这样我们可以使用相同的语法来获取这些值。

首先,让我们生成new_points下运费的预测值:

freight_pred <- predict(lm_fit, new_data = new_points)
freight_pred

在进行预测时,tidymodels 的约定是始终生成一个带有标准化列名的 tibble 结果。这种格式使得将原始数据与预测结果结合在一起变得简单且易于使用:

# 带有置信区间的预测
conf_int_pred <- predict(lm_fit, 
                         new_data = new_points, 
                         type = "conf_int")
conf_int_pred

# 结合
plot_data <- 
  new_points |> 
  bind_cols(freight_pred) |> 
  bind_cols(conf_int_pred)

# 画图
ggplot(plot_data, aes(x = approval)) + 
  geom_point(aes(y = .pred)) + 
  geom_errorbar(aes(ymin = .pred_lower, 
                    ymax = .pred_upper),
                width = .2) + 
  labs(y = "predicted freight")

使用像 linear_reg() 这样的函数来定义模型可能看起来有些多余,因为直接调用 lm() 会更加简洁。然而,标准的建模函数存在一个问题,它们并未将你想做的事情执行过程分开。例如,即使公式没有变化,执行公式的过程也必须在每次模型调用中重复进行;我们无法重复利用这些计算结果。

此外,使用 tidymodels 框架,我们可以通过逐步创建模型来做一些有趣的事情(而不是通过单个函数调用完成)。例如,tidymodels 的模型调优(model tuning)功能使用模型规范来声明哪些部分需要被调优。如果 linear_reg() 直接拟合了模型,那么实现这一点将变得非常困难。

模型评估

模型残差是诊断线性回归模型的核心。可以通过 augment() 函数提取拟合值和残差:

diagnostics <- augment(lm_fit,vendor_invoices_dec)
head(diagnostics)

预测和实际值比较

ggplot(diagnostics, aes(x = .pred, y = freight)) +
  geom_point() +
  geom_hline(yintercept = 0, color = "red", linetype = "dashed") +
  labs(title = "Fitted vs. Actual",
       x = "Fitted Values",
       y = "Actual") +
  theme_minimal()

预测值和残差

ggplot(diagnostics, aes(x = .pred, y = .resid)) +
  geom_point() +
  geom_hline(yintercept = 0, color = "red", linetype = "dashed") +
  labs(title = "Residuals vs Fitted",
       x = "Fitted Values",
       y = "Residuals") +
  theme_minimal()
  • 如果残差随机分布且没有明显模式,模型的线性假设成立。

  • 如果存在模式(如曲线趋势),说明可能需要引入非线性项。

残差的正态性检查

使用 QQ 图检查残差是否符合正态分布。

ggplot(diagnostics, aes(sample = .resid)) +
  stat_qq() +
  stat_qq_line(color = "red") +
  labs(title = "QQ Plot of Residuals",
       x = "Theoretical Quantiles",
       y = "Sample Quantiles") +
  theme_minimal()
  • 如果点沿参考线分布,说明残差接近正态分布。
  • 如果偏离参考线,可能需要检查数据的异常值或分布特性。